diff --git a/neqo-bin/src/lib.rs b/neqo-bin/src/lib.rs index 7b229a80b1..0036da4c76 100644 --- a/neqo-bin/src/lib.rs +++ b/neqo-bin/src/lib.rs @@ -118,6 +118,10 @@ pub struct QuicParameters { /// Whether to disable pacing. pub no_pacing: bool, + #[arg(long)] + /// Whether to disable path MTU discovery. + pub no_pmtud: bool, + #[arg(name = "preferred-address-v4", long)] /// An IPv4 address for the server preferred address. pub preferred_address_v4: Option, @@ -137,6 +141,7 @@ impl Default for QuicParameters { idle_timeout: 30, congestion_control: CongestionControlAlgorithm::NewReno, no_pacing: false, + no_pmtud: false, preferred_address_v4: None, preferred_address_v6: None, } @@ -203,7 +208,8 @@ impl QuicParameters { .max_streams(StreamType::UniDi, self.max_streams_uni) .idle_timeout(Duration::from_secs(self.idle_timeout)) .cc_algorithm(self.congestion_control) - .pacing(!self.no_pacing); + .pacing(!self.no_pacing) + .pmtud(!self.no_pmtud); if let Some(&first) = self.quic_version.first() { let all = if self.quic_version[1..].contains(&first) { diff --git a/neqo-http3/src/features/extended_connect/tests/webtransport/mod.rs b/neqo-http3/src/features/extended_connect/tests/webtransport/mod.rs index 75ec1c6909..8e9da9cc14 100644 --- a/neqo-http3/src/features/extended_connect/tests/webtransport/mod.rs +++ b/neqo-http3/src/features/extended_connect/tests/webtransport/mod.rs @@ -12,7 +12,7 @@ use std::{cell::RefCell, rc::Rc, time::Duration}; use neqo_common::event::Provider; use neqo_crypto::AuthenticationStatus; -use neqo_transport::{ConnectionParameters, StreamId, StreamType, MIN_INITIAL_PACKET_SIZE}; +use neqo_transport::{ConnectionParameters, Pmtud, StreamId, StreamType}; use test_fixture::{ anti_replay, fixture_init, now, CountingConnectionIdGenerator, DEFAULT_ADDR, DEFAULT_ALPN_H3, DEFAULT_KEYS, DEFAULT_SERVER_NAME, @@ -25,7 +25,8 @@ use crate::{ WebTransportServerEvent, WebTransportSessionAcceptAction, }; -const DATAGRAM_SIZE: u64 = MIN_INITIAL_PACKET_SIZE as u64; 
+// Leave space for large QUIC header. +const DATAGRAM_SIZE: u64 = Pmtud::default_plpmtu(DEFAULT_ADDR.ip()) as u64 - 40; pub fn wt_default_parameters() -> Http3Parameters { Http3Parameters::default() diff --git a/neqo-transport/Cargo.toml b/neqo-transport/Cargo.toml index c72df5e7f1..b26f74ffc8 100644 --- a/neqo-transport/Cargo.toml +++ b/neqo-transport/Cargo.toml @@ -24,6 +24,7 @@ neqo-common = { path = "../neqo-common" } neqo-crypto = { path = "../neqo-crypto" } qlog = { workspace = true } smallvec = { version = "1.11", default-features = false } +static_assertions = { version = "1.1", default-features = false } [dev-dependencies] criterion = { version = "0.5", default-features = false, features = ["html_reports"] } diff --git a/neqo-transport/src/cc/classic_cc.rs b/neqo-transport/src/cc/classic_cc.rs index d825fd28b1..f7beb3ad61 100644 --- a/neqo-transport/src/cc/classic_cc.rs +++ b/neqo-transport/src/cc/classic_cc.rs @@ -14,23 +14,18 @@ use std::{ use super::CongestionControl; use crate::{ - cc::MAX_DATAGRAM_SIZE, packet::PacketNumber, qlog::{self, QlogMetric}, recovery::SentPacket, rtt::RttEstimate, sender::PACING_BURST_SIZE, + Pmtud, }; #[rustfmt::skip] // to keep `::` and thus prevent conflict with `crate::qlog` use ::qlog::events::{quic::CongestionStateUpdated, EventData}; use neqo_common::{const_max, const_min, qdebug, qinfo, qlog::NeqoQlog, qtrace}; pub const CWND_INITIAL_PKTS: usize = 10; -pub const CWND_INITIAL: usize = const_min( - CWND_INITIAL_PKTS * MAX_DATAGRAM_SIZE, - const_max(2 * MAX_DATAGRAM_SIZE, 14720), -); -pub const CWND_MIN: usize = MAX_DATAGRAM_SIZE * 2; const PERSISTENT_CONG_THRESH: u32 = 3; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -90,13 +85,19 @@ pub trait WindowAdjustment: Display + Debug { curr_cwnd: usize, new_acked_bytes: usize, min_rtt: Duration, + max_datagram_size: usize, now: Instant, ) -> usize; /// This function is called when a congestion event has beed detected and it /// returns new (decreased) values of `curr_cwnd` 
and `acked_bytes`. /// This value can be very small; the calling code is responsible for ensuring that the /// congestion window doesn't drop below the minimum of `CWND_MIN`. - fn reduce_cwnd(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize); + fn reduce_cwnd( + &mut self, + curr_cwnd: usize, + acked_bytes: usize, + max_datagram_size: usize, + ) -> (usize, usize); /// Cubic needs this signal to reset its epoch. fn on_app_limited(&mut self); #[cfg(test)] @@ -122,10 +123,16 @@ pub struct ClassicCongestionControl { /// /// [1]: https://datatracker.ietf.org/doc/html/rfc9002#section-7.8 first_app_limited: PacketNumber, - + pmtud: Pmtud, qlog: NeqoQlog, } +impl ClassicCongestionControl { + pub const fn max_datagram_size(&self) -> usize { + self.pmtud.plpmtu() + } +} + impl Display for ClassicCongestionControl { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( @@ -159,6 +166,25 @@ impl CongestionControl for ClassicCongestionControl { self.congestion_window.saturating_sub(self.bytes_in_flight) } + #[must_use] + fn cwnd_min(&self) -> usize { + self.max_datagram_size() * 2 + } + + #[cfg(test)] + #[must_use] + fn cwnd_initial(&self) -> usize { + cwnd_initial(self.pmtud.plpmtu()) + } + + fn pmtud(&self) -> &Pmtud { + &self.pmtud + } + + fn pmtud_mut(&mut self) -> &mut Pmtud { + &mut self.pmtud + } + // Multi-packet version of OnPacketAckedCC fn on_packets_acked(&mut self, acked_pkts: &[SentPacket], rtt_est: &RttEstimate, now: Instant) { let mut is_app_limited = true; @@ -224,6 +250,7 @@ impl CongestionControl for ClassicCongestionControl { self.congestion_window, new_acked, rtt_est.minimum(), + self.max_datagram_size(), now, ); debug_assert!(bytes_for_increase > 0); @@ -231,12 +258,12 @@ impl CongestionControl for ClassicCongestionControl { // If we have sudden increase in allowed rate we actually increase cwnd gently. 
if self.acked_bytes >= bytes_for_increase { self.acked_bytes = 0; - self.congestion_window += MAX_DATAGRAM_SIZE; + self.congestion_window += self.max_datagram_size(); } self.acked_bytes += new_acked; if self.acked_bytes >= bytes_for_increase { self.acked_bytes -= bytes_for_increase; - self.congestion_window += MAX_DATAGRAM_SIZE; // or is this the current MTU? + self.congestion_window += self.max_datagram_size(); // or is this the current MTU? } // The number of bytes we require can go down over time with Cubic. // That might result in an excessive rate of increase, so limit the number of unused @@ -281,12 +308,24 @@ impl CongestionControl for ClassicCongestionControl { &[QlogMetric::BytesInFlight(self.bytes_in_flight)], ); - let congestion = self.on_congestion_event(lost_packets.last().unwrap()); + let is_pmtud_probe = self.pmtud.is_probe_filter(); + let mut lost_packets = lost_packets + .iter() + .filter(|pkt| !is_pmtud_probe(pkt)) + .rev() + .peekable(); + + // Lost PMTUD probes do not elicit a congestion control reaction. 
+ let Some(last_lost_packet) = lost_packets.peek() else { + return false; + }; + + let congestion = self.on_congestion_event(last_lost_packet); let persistent_congestion = self.detect_persistent_congestion( first_rtt_sample_time, prev_largest_acked_sent, pto, - lost_packets, + lost_packets.rev(), ); qdebug!( "on_packets_lost this={:p}, bytes_in_flight={}, cwnd={}, state={:?}", @@ -363,18 +402,23 @@ impl CongestionControl for ClassicCongestionControl { } } +const fn cwnd_initial(mtu: usize) -> usize { + const_min(CWND_INITIAL_PKTS * mtu, const_max(2 * mtu, 14_720)) +} + impl ClassicCongestionControl { - pub fn new(cc_algorithm: T) -> Self { + pub fn new(cc_algorithm: T, pmtud: Pmtud) -> Self { Self { cc_algorithm, state: State::SlowStart, - congestion_window: CWND_INITIAL, + congestion_window: cwnd_initial(pmtud.plpmtu()), bytes_in_flight: 0, acked_bytes: 0, ssthresh: usize::MAX, recovery_start: None, qlog: NeqoQlog::disabled(), first_app_limited: 0, + pmtud, } } @@ -425,12 +469,12 @@ impl ClassicCongestionControl { } } - fn detect_persistent_congestion( + fn detect_persistent_congestion<'a>( &mut self, first_rtt_sample_time: Option, prev_largest_acked_sent: Option, pto: Duration, - lost_packets: &[SentPacket], + lost_packets: impl IntoIterator, ) -> bool { if first_rtt_sample_time.is_none() { return false; @@ -447,7 +491,7 @@ impl ClassicCongestionControl { // as we might not have sent PTO packets soon enough after those. 
let cutoff = max(first_rtt_sample_time, prev_largest_acked_sent); for p in lost_packets - .iter() + .into_iter() .skip_while(|p| Some(p.time_sent()) < cutoff) { if p.pn() != last_pn + 1 { @@ -466,7 +510,7 @@ impl ClassicCongestionControl { .expect("time is monotonic"); if elapsed > pc_period { qinfo!([self], "persistent congestion"); - self.congestion_window = CWND_MIN; + self.congestion_window = self.cwnd_min(); self.acked_bytes = 0; self.set_state(State::PersistentCongestion); qlog::metrics_updated( @@ -502,10 +546,12 @@ impl ClassicCongestionControl { return false; } - let (cwnd, acked_bytes) = self - .cc_algorithm - .reduce_cwnd(self.congestion_window, self.acked_bytes); - self.congestion_window = max(cwnd, CWND_MIN); + let (cwnd, acked_bytes) = self.cc_algorithm.reduce_cwnd( + self.congestion_window, + self.acked_bytes, + self.max_datagram_size(), + ); + self.congestion_window = max(cwnd, self.cwnd_min()); self.acked_bytes = acked_bytes; self.ssthresh = self.congestion_window; qdebug!( @@ -537,33 +583,37 @@ impl ClassicCongestionControl { } else { // We're not limited if the in-flight data is within a single burst of the // congestion window. 
- (self.bytes_in_flight + MAX_DATAGRAM_SIZE * PACING_BURST_SIZE) < self.congestion_window + (self.bytes_in_flight + self.max_datagram_size() * PACING_BURST_SIZE) + < self.congestion_window } } } #[cfg(test)] mod tests { - use std::time::{Duration, Instant}; + use std::{ + net::{IpAddr, Ipv4Addr}, + time::{Duration, Instant}, + }; use neqo_common::{qinfo, IpTosEcn}; use test_fixture::now; - use super::{ - ClassicCongestionControl, WindowAdjustment, CWND_INITIAL, CWND_MIN, PERSISTENT_CONG_THRESH, - }; + use super::{ClassicCongestionControl, WindowAdjustment, PERSISTENT_CONG_THRESH}; use crate::{ cc::{ classic_cc::State, cubic::{Cubic, CUBIC_BETA_USIZE_DIVIDEND, CUBIC_BETA_USIZE_DIVISOR}, new_reno::NewReno, - CongestionControl, CongestionControlAlgorithm, CWND_INITIAL_PKTS, MAX_DATAGRAM_SIZE, + CongestionControl, CongestionControlAlgorithm, CWND_INITIAL_PKTS, }, packet::{PacketNumber, PacketType}, recovery::SentPacket, rtt::RttEstimate, + Pmtud, }; + const IP_ADDR: IpAddr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); const PTO: Duration = Duration::from_millis(100); const RTT: Duration = Duration::from_millis(98); const RTT_ESTIMATE: RttEstimate = RttEstimate::from_duration(Duration::from_millis(98)); @@ -577,13 +627,13 @@ mod tests { const PC: Duration = Duration::from_nanos(100_000_000 * (PERSISTENT_CONG_THRESH as u64) + 1); fn cwnd_is_default(cc: &ClassicCongestionControl) { - assert_eq!(cc.cwnd(), CWND_INITIAL); + assert_eq!(cc.cwnd(), cc.cwnd_initial()); assert_eq!(cc.ssthresh(), usize::MAX); } fn cwnd_is_halved(cc: &ClassicCongestionControl) { - assert_eq!(cc.cwnd(), CWND_INITIAL / 2); - assert_eq!(cc.ssthresh(), CWND_INITIAL / 2); + assert_eq!(cc.cwnd(), cc.cwnd_initial() / 2); + assert_eq!(cc.ssthresh(), cc.cwnd_initial() / 2); } fn lost(pn: PacketNumber, ack_eliciting: bool, t: Duration) -> SentPacket { @@ -600,22 +650,23 @@ mod tests { fn congestion_control(cc: CongestionControlAlgorithm) -> Box { match cc { - CongestionControlAlgorithm::NewReno => { - 
Box::new(ClassicCongestionControl::new(NewReno::default())) - } - CongestionControlAlgorithm::Cubic => { - Box::new(ClassicCongestionControl::new(Cubic::default())) - } + CongestionControlAlgorithm::NewReno => Box::new(ClassicCongestionControl::new( + NewReno::default(), + Pmtud::new(IP_ADDR), + )), + CongestionControlAlgorithm::Cubic => Box::new(ClassicCongestionControl::new( + Cubic::default(), + Pmtud::new(IP_ADDR), + )), } } fn persistent_congestion_by_algorithm( - cc_alg: CongestionControlAlgorithm, + mut cc: Box, reduced_cwnd: usize, lost_packets: &[SentPacket], persistent_expected: bool, ) { - let mut cc = congestion_control(cc_alg); for p in lost_packets { cc.on_packet_sent(p); } @@ -624,7 +675,7 @@ mod tests { let persistent = if cc.cwnd() == reduced_cwnd { false - } else if cc.cwnd() == CWND_MIN { + } else if cc.cwnd() == cc.cwnd_min() { true } else { panic!("unexpected cwnd"); @@ -633,15 +684,15 @@ mod tests { } fn persistent_congestion(lost_packets: &[SentPacket], persistent_expected: bool) { + let cc = congestion_control(CongestionControlAlgorithm::NewReno); + let cwnd_initial = cc.cwnd_initial(); + persistent_congestion_by_algorithm(cc, cwnd_initial / 2, lost_packets, persistent_expected); + + let cc = congestion_control(CongestionControlAlgorithm::Cubic); + let cwnd_initial = cc.cwnd_initial(); persistent_congestion_by_algorithm( - CongestionControlAlgorithm::NewReno, - CWND_INITIAL / 2, - lost_packets, - persistent_expected, - ); - persistent_congestion_by_algorithm( - CongestionControlAlgorithm::Cubic, - CWND_INITIAL * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR, + cc, + cwnd_initial * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR, lost_packets, persistent_expected, ); @@ -823,19 +874,19 @@ mod tests { rtt_time: u32, lost: &[SentPacket], ) -> bool { - assert_eq!(cc.cwnd(), CWND_INITIAL); + assert_eq!(cc.cwnd(), cc.cwnd_initial()); let last_ack = Some(by_pto(last_ack)); let rtt_time = Some(by_pto(rtt_time)); // Persistent congestion 
is never declared if the RTT time is `None`. - cc.detect_persistent_congestion(None, None, PTO, lost); - assert_eq!(cc.cwnd(), CWND_INITIAL); - cc.detect_persistent_congestion(None, last_ack, PTO, lost); - assert_eq!(cc.cwnd(), CWND_INITIAL); + cc.detect_persistent_congestion(None, None, PTO, lost.iter()); + assert_eq!(cc.cwnd(), cc.cwnd_initial()); + cc.detect_persistent_congestion(None, last_ack, PTO, lost.iter()); + assert_eq!(cc.cwnd(), cc.cwnd_initial()); - cc.detect_persistent_congestion(rtt_time, last_ack, PTO, lost); - cc.cwnd() == CWND_MIN + cc.detect_persistent_congestion(rtt_time, last_ack, PTO, lost.iter()); + cc.cwnd() == cc.cwnd_min() } /// No persistent congestion can be had if there are no lost packets. @@ -843,13 +894,13 @@ mod tests { fn persistent_congestion_no_lost() { let lost = make_lost(&[]); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(NewReno::default()), + ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)), 0, 0, &lost )); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(Cubic::default()), + ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)), 0, 0, &lost @@ -861,13 +912,13 @@ mod tests { fn persistent_congestion_one_lost() { let lost = make_lost(&[1]); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(NewReno::default()), + ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)), 0, 0, &lost )); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(Cubic::default()), + ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)), 0, 0, &lost @@ -881,37 +932,37 @@ mod tests { // sample are not considered. So 0 is ignored. 
let lost = make_lost(&[0, PERSISTENT_CONG_THRESH + 1, PERSISTENT_CONG_THRESH + 2]); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(NewReno::default()), + ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)), 1, 1, &lost )); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(NewReno::default()), + ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)), 0, 1, &lost )); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(NewReno::default()), + ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)), 1, 0, &lost )); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(Cubic::default()), + ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)), 1, 1, &lost )); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(Cubic::default()), + ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)), 0, 1, &lost )); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(Cubic::default()), + ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)), 1, 0, &lost @@ -932,13 +983,13 @@ mod tests { lost[0].len(), ); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(NewReno::default()), + ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)), 0, 0, &lost )); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(Cubic::default()), + ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)), 0, 0, &lost @@ -952,13 +1003,13 @@ mod tests { fn persistent_congestion_min() { let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); assert!(persistent_congestion_by_pto( - ClassicCongestionControl::new(NewReno::default()), + ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)), 0, 0, &lost )); assert!(persistent_congestion_by_pto( - ClassicCongestionControl::new(Cubic::default()), + 
ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)), 0, 0, &lost @@ -971,17 +1022,17 @@ mod tests { #[test] fn persistent_congestion_no_prev_ack_newreno() { let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); - let mut cc = ClassicCongestionControl::new(NewReno::default()); - cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, &lost); - assert_eq!(cc.cwnd(), CWND_MIN); + let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)); + cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, lost.iter()); + assert_eq!(cc.cwnd(), cc.cwnd_min()); } #[test] fn persistent_congestion_no_prev_ack_cubic() { let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); - let mut cc = ClassicCongestionControl::new(Cubic::default()); - cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, &lost); - assert_eq!(cc.cwnd(), CWND_MIN); + let mut cc = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)); + cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, lost.iter()); + assert_eq!(cc.cwnd(), cc.cwnd_min()); } /// The code asserts on ordering errors. 
@@ -990,7 +1041,7 @@ mod tests { fn persistent_congestion_unsorted_newreno() { let lost = make_lost(&[PERSISTENT_CONG_THRESH + 2, 1]); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(NewReno::default()), + ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)), 0, 0, &lost @@ -1003,7 +1054,7 @@ mod tests { fn persistent_congestion_unsorted_cubic() { let lost = make_lost(&[PERSISTENT_CONG_THRESH + 2, 1]); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(Cubic::default()), + ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)), 0, 0, &lost @@ -1014,7 +1065,7 @@ mod tests { fn app_limited_slow_start() { const BELOW_APP_LIMIT_PKTS: usize = 5; const ABOVE_APP_LIMIT_PKTS: usize = BELOW_APP_LIMIT_PKTS + 1; - let mut cc = ClassicCongestionControl::new(NewReno::default()); + let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)); let cwnd = cc.congestion_window; let mut now = now(); let mut next_pn = 0; @@ -1031,13 +1082,16 @@ mod tests { now, true, Vec::new(), - MAX_DATAGRAM_SIZE, + cc.max_datagram_size(), ); next_pn += 1; cc.on_packet_sent(&p); pkts.push(p); } - assert_eq!(cc.bytes_in_flight(), packet_burst_size * MAX_DATAGRAM_SIZE); + assert_eq!( + cc.bytes_in_flight(), + packet_burst_size * cc.max_datagram_size() + ); now += RTT; cc.on_packets_acked(&pkts, &RTT_ESTIMATE, now); assert_eq!(cc.bytes_in_flight(), 0); @@ -1056,7 +1110,7 @@ mod tests { now, true, Vec::new(), - MAX_DATAGRAM_SIZE, + cc.max_datagram_size(), ); next_pn += 1; cc.on_packet_sent(&p); @@ -1064,7 +1118,7 @@ mod tests { } assert_eq!( cc.bytes_in_flight(), - ABOVE_APP_LIMIT_PKTS * MAX_DATAGRAM_SIZE + ABOVE_APP_LIMIT_PKTS * cc.max_datagram_size() ); now += RTT; // Check if congestion window gets increased for all packets currently in flight @@ -1073,11 +1127,18 @@ mod tests { assert_eq!( cc.bytes_in_flight(), - (ABOVE_APP_LIMIT_PKTS - i - 1) * MAX_DATAGRAM_SIZE + (ABOVE_APP_LIMIT_PKTS - i - 1) * 
cc.max_datagram_size() ); // increase acked_bytes with each packet - qinfo!("{} {}", cc.congestion_window, cwnd + i * MAX_DATAGRAM_SIZE); - assert_eq!(cc.congestion_window, cwnd + (i + 1) * MAX_DATAGRAM_SIZE); + qinfo!( + "{} {}", + cc.congestion_window, + cwnd + i * cc.max_datagram_size() + ); + assert_eq!( + cc.congestion_window, + cwnd + (i + 1) * cc.max_datagram_size() + ); assert_eq!(cc.acked_bytes, 0); } } @@ -1088,7 +1149,7 @@ mod tests { const BELOW_APP_LIMIT_PKTS: usize = CWND_PKTS_CA - 2; const ABOVE_APP_LIMIT_PKTS: usize = BELOW_APP_LIMIT_PKTS + 1; - let mut cc = ClassicCongestionControl::new(NewReno::default()); + let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)); let mut now = now(); // Change state to congestion avoidance by introducing loss. @@ -1100,7 +1161,7 @@ mod tests { now, true, Vec::new(), - MAX_DATAGRAM_SIZE, + cc.max_datagram_size(), ); cc.on_packet_sent(&p_lost); cwnd_is_default(&cc); @@ -1114,7 +1175,7 @@ mod tests { now, true, Vec::new(), - MAX_DATAGRAM_SIZE, + cc.max_datagram_size(), ); cc.on_packet_sent(&p_not_lost); now += RTT; @@ -1138,20 +1199,23 @@ mod tests { now, true, Vec::new(), - MAX_DATAGRAM_SIZE, + cc.max_datagram_size(), ); next_pn += 1; cc.on_packet_sent(&p); pkts.push(p); } - assert_eq!(cc.bytes_in_flight(), packet_burst_size * MAX_DATAGRAM_SIZE); + assert_eq!( + cc.bytes_in_flight(), + packet_burst_size * cc.max_datagram_size() + ); now += RTT; for (i, pkt) in pkts.into_iter().enumerate() { cc.on_packets_acked(&[pkt], &RTT_ESTIMATE, now); assert_eq!( cc.bytes_in_flight(), - (packet_burst_size - i - 1) * MAX_DATAGRAM_SIZE + (packet_burst_size - i - 1) * cc.max_datagram_size() ); cwnd_is_halved(&cc); // CWND doesn't grow because we're app limited assert_eq!(cc.acked_bytes, 0); @@ -1169,7 +1233,7 @@ mod tests { now, true, Vec::new(), - MAX_DATAGRAM_SIZE, + cc.max_datagram_size(), ); next_pn += 1; cc.on_packet_sent(&p); @@ -1177,7 +1241,7 @@ mod tests { } assert_eq!( cc.bytes_in_flight(), - 
ABOVE_APP_LIMIT_PKTS * MAX_DATAGRAM_SIZE + ABOVE_APP_LIMIT_PKTS * cc.max_datagram_size() ); now += RTT; let mut last_acked_bytes = 0; @@ -1187,7 +1251,7 @@ mod tests { assert_eq!( cc.bytes_in_flight(), - (ABOVE_APP_LIMIT_PKTS - i - 1) * MAX_DATAGRAM_SIZE + (ABOVE_APP_LIMIT_PKTS - i - 1) * cc.max_datagram_size() ); // The cwnd doesn't increase, but the acked_bytes do, which will eventually lead to an // increase, once the number of bytes reaches the necessary level @@ -1200,7 +1264,7 @@ mod tests { #[test] fn ecn_ce() { - let mut cc = ClassicCongestionControl::new(NewReno::default()); + let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)); let p_ce = SentPacket::new( PacketType::Short, 1, @@ -1208,7 +1272,7 @@ mod tests { now(), true, Vec::new(), - MAX_DATAGRAM_SIZE, + cc.max_datagram_size(), ); cc.on_packet_sent(&p_ce); cwnd_is_default(&cc); diff --git a/neqo-transport/src/cc/cubic.rs b/neqo-transport/src/cc/cubic.rs index d7ac1068f7..c8c0188ddb 100644 --- a/neqo-transport/src/cc/cubic.rs +++ b/neqo-transport/src/cc/cubic.rs @@ -11,7 +11,7 @@ use std::{ use neqo_common::qtrace; -use crate::cc::{classic_cc::WindowAdjustment, MAX_DATAGRAM_SIZE_F64}; +use crate::cc::classic_cc::WindowAdjustment; // CUBIC congestion control @@ -38,7 +38,7 @@ const EXPONENTIAL_GROWTH_REDUCTION: f64 = 2.0; /// Convert an integer congestion window value into a floating point value. /// This has the effect of reducing larger values to `1<<53`. /// If you have a congestion window that large, something is probably wrong. 
-fn convert_to_f64(v: usize) -> f64 { +pub fn convert_to_f64(v: usize) -> f64 { let mut f_64 = f64::from(u32::try_from(v >> 21).unwrap_or(u32::MAX)); f_64 *= 2_097_152.0; // f_64 <<= 21 f_64 += f64::from(u32::try_from(v & 0x1f_ffff).unwrap()); @@ -91,17 +91,23 @@ impl Cubic { /// /// From that equation we can calculate K as: /// K = cubic_root((W_max - W_cubic) / C / MSS); - fn calc_k(&self, curr_cwnd: f64) -> f64 { - ((self.w_max - curr_cwnd) / CUBIC_C / MAX_DATAGRAM_SIZE_F64).cbrt() + fn calc_k(&self, curr_cwnd: f64, max_datagram_size: usize) -> f64 { + ((self.w_max - curr_cwnd) / CUBIC_C / convert_to_f64(max_datagram_size)).cbrt() } /// W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1) /// t is relative to the start of the congestion avoidance phase and it is in seconds. - fn w_cubic(&self, t: f64) -> f64 { - (CUBIC_C * (t - self.k).powi(3)).mul_add(MAX_DATAGRAM_SIZE_F64, self.w_max) + fn w_cubic(&self, t: f64, max_datagram_size: usize) -> f64 { + (CUBIC_C * (t - self.k).powi(3)).mul_add(convert_to_f64(max_datagram_size), self.w_max) } - fn start_epoch(&mut self, curr_cwnd_f64: f64, new_acked_f64: f64, now: Instant) { + fn start_epoch( + &mut self, + curr_cwnd_f64: f64, + new_acked_f64: f64, + max_datagram_size: usize, + now: Instant, + ) { self.ca_epoch_start = Some(now); // reset tcp_acked_bytes and estimated_tcp_cwnd; self.tcp_acked_bytes = new_acked_f64; @@ -111,7 +117,7 @@ impl Cubic { self.k = 0.0; } else { self.w_max = self.last_max_cwnd; - self.k = self.calc_k(curr_cwnd_f64); + self.k = self.calc_k(curr_cwnd_f64, max_datagram_size); } qtrace!([self], "New epoch"); } @@ -126,13 +132,14 @@ impl WindowAdjustment for Cubic { curr_cwnd: usize, new_acked_bytes: usize, min_rtt: Duration, + max_datagram_size: usize, now: Instant, ) -> usize { let curr_cwnd_f64 = convert_to_f64(curr_cwnd); let new_acked_f64 = convert_to_f64(new_acked_bytes); if self.ca_epoch_start.is_none() { // This is a start of a new congestion avoidance phase. 
- self.start_epoch(curr_cwnd_f64, new_acked_f64, now); + self.start_epoch(curr_cwnd_f64, new_acked_f64, max_datagram_size, now); } else { self.tcp_acked_bytes += new_acked_f64; } @@ -149,13 +156,14 @@ impl WindowAdjustment for Cubic { } }) .as_secs_f64(); - let target_cubic = self.w_cubic(time_ca); + let target_cubic = self.w_cubic(time_ca, max_datagram_size); + let max_datagram_size = convert_to_f64(max_datagram_size); let tcp_cnt = self.estimated_tcp_cwnd / CUBIC_ALPHA; let incr = (self.tcp_acked_bytes / tcp_cnt).floor(); if incr > 0.0 { self.tcp_acked_bytes -= incr * tcp_cnt; - self.estimated_tcp_cwnd += incr * MAX_DATAGRAM_SIZE_F64; + self.estimated_tcp_cwnd += incr * max_datagram_size; } let target_cwnd = target_cubic.max(self.estimated_tcp_cwnd); @@ -167,27 +175,32 @@ impl WindowAdjustment for Cubic { // If the target is not significantly higher than the congestion window, require a very // large amount of acknowledged data (effectively block increases). let mut acked_to_increase = - MAX_DATAGRAM_SIZE_F64 * curr_cwnd_f64 / (target_cwnd - curr_cwnd_f64).max(1.0); + max_datagram_size * curr_cwnd_f64 / (target_cwnd - curr_cwnd_f64).max(1.0); // Limit increase to max 1 MSS per EXPONENTIAL_GROWTH_REDUCTION ack packets. // This effectively limits target_cwnd to (1 + 1 / EXPONENTIAL_GROWTH_REDUCTION) cwnd. - acked_to_increase = - acked_to_increase.max(EXPONENTIAL_GROWTH_REDUCTION * MAX_DATAGRAM_SIZE_F64); + acked_to_increase = acked_to_increase.max(EXPONENTIAL_GROWTH_REDUCTION * max_datagram_size); acked_to_increase as usize } - fn reduce_cwnd(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize) { + fn reduce_cwnd( + &mut self, + curr_cwnd: usize, + acked_bytes: usize, + max_datagram_size: usize, + ) -> (usize, usize) { let curr_cwnd_f64 = convert_to_f64(curr_cwnd); // Fast Convergence // If congestion event occurs before the maximum congestion window before the last // congestion event, we reduce the the maximum congestion window and thereby W_max. 
// check cwnd + MAX_DATAGRAM_SIZE instead of cwnd because with cwnd in bytes, cwnd may be // slightly off. - self.last_max_cwnd = if curr_cwnd_f64 + MAX_DATAGRAM_SIZE_F64 < self.last_max_cwnd { - curr_cwnd_f64 * CUBIC_FAST_CONVERGENCE - } else { - curr_cwnd_f64 - }; + self.last_max_cwnd = + if curr_cwnd_f64 + convert_to_f64(max_datagram_size) < self.last_max_cwnd { + curr_cwnd_f64 * CUBIC_FAST_CONVERGENCE + } else { + curr_cwnd_f64 + }; self.ca_epoch_start = None; ( curr_cwnd * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR, diff --git a/neqo-transport/src/cc/mod.rs b/neqo-transport/src/cc/mod.rs index e85413b491..bbb47c4fd0 100644 --- a/neqo-transport/src/cc/mod.rs +++ b/neqo-transport/src/cc/mod.rs @@ -14,7 +14,7 @@ use std::{ use neqo_common::qlog::NeqoQlog; -use crate::{path::PATH_MTU_V6, recovery::SentPacket, rtt::RttEstimate, Error}; +use crate::{recovery::SentPacket, rtt::RttEstimate, Error, Pmtud}; mod classic_cc; mod cubic; @@ -22,14 +22,10 @@ mod new_reno; pub use classic_cc::ClassicCongestionControl; #[cfg(test)] -pub use classic_cc::{CWND_INITIAL, CWND_INITIAL_PKTS, CWND_MIN}; +pub use classic_cc::CWND_INITIAL_PKTS; pub use cubic::Cubic; pub use new_reno::NewReno; -pub const MAX_DATAGRAM_SIZE: usize = PATH_MTU_V6; -#[allow(clippy::cast_precision_loss)] -pub const MAX_DATAGRAM_SIZE_F64: f64 = MAX_DATAGRAM_SIZE as f64; - pub trait CongestionControl: Display + Debug { fn set_qlog(&mut self, qlog: NeqoQlog); @@ -42,6 +38,19 @@ pub trait CongestionControl: Display + Debug { #[must_use] fn cwnd_avail(&self) -> usize; + #[must_use] + fn cwnd_min(&self) -> usize; + + #[cfg(test)] + #[must_use] + fn cwnd_initial(&self) -> usize; + + #[must_use] + fn pmtud(&self) -> &Pmtud; + + #[must_use] + fn pmtud_mut(&mut self) -> &mut Pmtud; + fn on_packets_acked(&mut self, acked_pkts: &[SentPacket], rtt_est: &RttEstimate, now: Instant); /// Returns true if the congestion window was reduced. 
diff --git a/neqo-transport/src/cc/new_reno.rs b/neqo-transport/src/cc/new_reno.rs index 47d0d56f37..cba431bdf2 100644 --- a/neqo-transport/src/cc/new_reno.rs +++ b/neqo-transport/src/cc/new_reno.rs @@ -29,12 +29,19 @@ impl WindowAdjustment for NewReno { curr_cwnd: usize, _new_acked_bytes: usize, _min_rtt: Duration, + _max_datagram_size: usize, + _now: Instant, ) -> usize { curr_cwnd } - fn reduce_cwnd(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize) { + fn reduce_cwnd( + &mut self, + curr_cwnd: usize, + acked_bytes: usize, + _max_datagram_size: usize, + ) -> (usize, usize) { (curr_cwnd / 2, acked_bytes / 2) } diff --git a/neqo-transport/src/cc/tests/cubic.rs b/neqo-transport/src/cc/tests/cubic.rs index 9c9cec17cf..4f75c46504 100644 --- a/neqo-transport/src/cc/tests/cubic.rs +++ b/neqo-transport/src/cc/tests/cubic.rs @@ -8,6 +8,7 @@ #![allow(clippy::cast_sign_loss)] use std::{ + net::{IpAddr, Ipv4Addr}, ops::Sub, time::{Duration, Instant}, }; @@ -17,26 +18,29 @@ use test_fixture::now; use crate::{ cc::{ - classic_cc::{ClassicCongestionControl, CWND_INITIAL}, + classic_cc::ClassicCongestionControl, cubic::{ - Cubic, CUBIC_ALPHA, CUBIC_BETA_USIZE_DIVIDEND, CUBIC_BETA_USIZE_DIVISOR, CUBIC_C, - CUBIC_FAST_CONVERGENCE, + convert_to_f64, Cubic, CUBIC_ALPHA, CUBIC_BETA_USIZE_DIVIDEND, + CUBIC_BETA_USIZE_DIVISOR, CUBIC_C, CUBIC_FAST_CONVERGENCE, }, - CongestionControl, MAX_DATAGRAM_SIZE, MAX_DATAGRAM_SIZE_F64, + CongestionControl, }, packet::PacketType, + pmtud::Pmtud, recovery::SentPacket, rtt::RttEstimate, }; +const IP_ADDR: IpAddr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); const RTT: Duration = Duration::from_millis(100); -const RTT_ESTIMATE: RttEstimate = RttEstimate::from_duration(Duration::from_millis(100)); -const CWND_INITIAL_F64: f64 = 10.0 * MAX_DATAGRAM_SIZE_F64; -const CWND_INITIAL_10_F64: f64 = 10.0 * CWND_INITIAL_F64; -const CWND_INITIAL_10: usize = 10 * CWND_INITIAL; -const CWND_AFTER_LOSS: usize = CWND_INITIAL * CUBIC_BETA_USIZE_DIVIDEND / 
CUBIC_BETA_USIZE_DIVISOR; -const CWND_AFTER_LOSS_SLOW_START: usize = - (CWND_INITIAL + MAX_DATAGRAM_SIZE) * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR; + +const fn cwnd_after_loss(cwnd: usize) -> usize { + cwnd * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR +} + +const fn cwnd_after_loss_slow_start(cwnd: usize, mtu: usize) -> usize { + (cwnd + mtu) * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR +} fn fill_cwnd(cc: &mut ClassicCongestionControl, mut next_pn: u64, now: Instant) -> u64 { while cc.bytes_in_flight() < cc.cwnd() { @@ -47,7 +51,7 @@ fn fill_cwnd(cc: &mut ClassicCongestionControl, mut next_pn: u64, now: In now, true, Vec::new(), - MAX_DATAGRAM_SIZE, + cc.max_datagram_size(), ); cc.on_packet_sent(&sent); next_pn += 1; @@ -63,9 +67,9 @@ fn ack_packet(cc: &mut ClassicCongestionControl, pn: u64, now: Instant) { now, true, Vec::new(), - MAX_DATAGRAM_SIZE, + cc.max_datagram_size(), ); - cc.on_packets_acked(&[acked], &RTT_ESTIMATE, now); + cc.on_packets_acked(&[acked], &RttEstimate::from_duration(RTT), now); } fn packet_lost(cc: &mut ClassicCongestionControl, pn: u64) { @@ -77,19 +81,21 @@ fn packet_lost(cc: &mut ClassicCongestionControl, pn: u64) { now(), true, Vec::new(), - MAX_DATAGRAM_SIZE, + cc.max_datagram_size(), ); cc.on_packets_lost(None, None, PTO, &[p_lost]); } -fn expected_tcp_acks(cwnd_rtt_start: usize) -> u64 { - (f64::from(i32::try_from(cwnd_rtt_start).unwrap()) / MAX_DATAGRAM_SIZE_F64 / CUBIC_ALPHA) +fn expected_tcp_acks(cwnd_rtt_start: usize, mtu: usize) -> u64 { + (f64::from(i32::try_from(cwnd_rtt_start).unwrap()) + / f64::from(i32::try_from(mtu).unwrap()) + / CUBIC_ALPHA) .round() as u64 } #[test] fn tcp_phase() { - let mut cubic = ClassicCongestionControl::new(Cubic::default()); + let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)); // change to congestion avoidance state. 
cubic.set_ssthresh(1); @@ -115,9 +121,10 @@ fn tcp_phase() { for _ in 0..num_tcp_increases { let cwnd_rtt_start = cubic.cwnd(); // Expected acks during a period of RTT / CUBIC_ALPHA. - let acks = expected_tcp_acks(cwnd_rtt_start); + let acks = expected_tcp_acks(cwnd_rtt_start, cubic.max_datagram_size()); // The time between acks if they are ideally paced over a RTT. - let time_increase = RTT / u32::try_from(cwnd_rtt_start / MAX_DATAGRAM_SIZE).unwrap(); + let time_increase = + RTT / u32::try_from(cwnd_rtt_start / cubic.max_datagram_size()).unwrap(); for _ in 0..acks { now += time_increase; @@ -126,7 +133,7 @@ fn tcp_phase() { next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now); } - assert_eq!(cubic.cwnd() - cwnd_rtt_start, MAX_DATAGRAM_SIZE); + assert_eq!(cubic.cwnd() - cwnd_rtt_start, cubic.max_datagram_size()); } // The next increase will be according to the cubic equation. @@ -134,8 +141,8 @@ fn tcp_phase() { let cwnd_rtt_start = cubic.cwnd(); // cwnd_rtt_start has change, therefore calculate new time_increase (the time // between acks if they are ideally paced over a RTT). - let time_increase = RTT / u32::try_from(cwnd_rtt_start / MAX_DATAGRAM_SIZE).unwrap(); - let mut num_acks = 0; // count the number of acks. until cwnd is increased by MAX_DATAGRAM_SIZE. + let time_increase = RTT / u32::try_from(cwnd_rtt_start / cubic.max_datagram_size()).unwrap(); + let mut num_acks = 0; // count the number of acks. until cwnd is increased by cubic.max_datagram_size(). while cwnd_rtt_start == cubic.cwnd() { num_acks += 1; @@ -147,7 +154,7 @@ fn tcp_phase() { // Make sure that the increase is not according to TCP equation, i.e., that it took // less than RTT / CUBIC_ALPHA. 
- let expected_ack_tcp_increase = expected_tcp_acks(cwnd_rtt_start); + let expected_ack_tcp_increase = expected_tcp_acks(cwnd_rtt_start, cubic.max_datagram_size()); assert!(num_acks < expected_ack_tcp_increase); // This first increase after a TCP phase may be shorter than what it would take by a regular @@ -159,7 +166,8 @@ fn tcp_phase() { let elapsed_time = now - start_time; // calculate new time_increase. - let time_increase = RTT / u32::try_from(cwnd_rtt_start_after_tcp / MAX_DATAGRAM_SIZE).unwrap(); + let time_increase = + RTT / u32::try_from(cwnd_rtt_start_after_tcp / cubic.max_datagram_size()).unwrap(); let mut num_acks2 = 0; // count the number of acks. until cwnd is increased by MAX_DATAGRAM_SIZE. while cwnd_rtt_start_after_tcp == cubic.cwnd() { @@ -170,7 +178,8 @@ fn tcp_phase() { next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now); } - let expected_ack_tcp_increase2 = expected_tcp_acks(cwnd_rtt_start_after_tcp); + let expected_ack_tcp_increase2 = + expected_tcp_acks(cwnd_rtt_start_after_tcp, cubic.max_datagram_size()); assert!(num_acks2 < expected_ack_tcp_increase2); // The time needed to increase cwnd by MAX_DATAGRAM_SIZE using the cubic equation will be @@ -193,10 +202,11 @@ fn tcp_phase() { #[test] fn cubic_phase() { - let mut cubic = ClassicCongestionControl::new(Cubic::default()); + let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)); + let cwnd_initial_f64: f64 = convert_to_f64(cubic.cwnd_initial()); // Set last_max_cwnd to a higher number make sure that cc is the cubic phase (cwnd is calculated // by the cubic equation). - cubic.set_last_max_cwnd(CWND_INITIAL_10_F64); + cubic.set_last_max_cwnd(cwnd_initial_f64 * 10.0); // Set ssthresh to something small to make sure that cc is in the congection avoidance phase. 
cubic.set_ssthresh(1); let mut now = now(); @@ -205,7 +215,10 @@ fn cubic_phase() { next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now); - let k = ((CWND_INITIAL_10_F64 - CWND_INITIAL_F64) / CUBIC_C / MAX_DATAGRAM_SIZE_F64).cbrt(); + let k = (cwnd_initial_f64.mul_add(10.0, -cwnd_initial_f64) + / CUBIC_C + / convert_to_f64(cubic.max_datagram_size())) + .cbrt(); let epoch_start = now; // The number of RTT until W_max is reached. @@ -213,7 +226,7 @@ fn cubic_phase() { for _ in 0..num_rtts_w_max { let cwnd_rtt_start = cubic.cwnd(); // Expected acks - let acks = cwnd_rtt_start / MAX_DATAGRAM_SIZE; + let acks = cwnd_rtt_start / cubic.max_datagram_size(); let time_increase = RTT / u32::try_from(acks).unwrap(); for _ in 0..acks { now += time_increase; @@ -223,12 +236,15 @@ fn cubic_phase() { } let expected = (CUBIC_C * ((now - epoch_start).as_secs_f64() - k).powi(3)) - .mul_add(MAX_DATAGRAM_SIZE_F64, CWND_INITIAL_10_F64) + .mul_add( + convert_to_f64(cubic.max_datagram_size()), + cwnd_initial_f64 * 10.0, + ) .round() as usize; - assert_within(cubic.cwnd(), expected, MAX_DATAGRAM_SIZE); + assert_within(cubic.cwnd(), expected, cubic.max_datagram_size()); } - assert_eq!(cubic.cwnd(), CWND_INITIAL_10); + assert_eq!(cubic.cwnd(), cubic.cwnd_initial() * 10); } fn assert_within + PartialOrd + Copy>(value: T, expected: T, margin: T) { @@ -241,7 +257,7 @@ fn assert_within + PartialOrd + Copy>(value: T, expected: T, #[test] fn congestion_event_slow_start() { - let mut cubic = ClassicCongestionControl::new(Cubic::default()); + let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)); _ = fill_cwnd(&mut cubic, 0, now()); ack_packet(&mut cubic, 0, now()); @@ -249,86 +265,96 @@ fn congestion_event_slow_start() { assert_within(cubic.last_max_cwnd(), 0.0, f64::EPSILON); // cwnd is increased by 1 in slow start phase, after an ack. 
- assert_eq!(cubic.cwnd(), CWND_INITIAL + MAX_DATAGRAM_SIZE); + assert_eq!( + cubic.cwnd(), + cubic.cwnd_initial() + cubic.max_datagram_size() + ); // Trigger a congestion_event in slow start phase packet_lost(&mut cubic, 1); // last_max_cwnd is equal to cwnd before decrease. + let cwnd_initial_f64: f64 = convert_to_f64(cubic.cwnd_initial()); assert_within( cubic.last_max_cwnd(), - CWND_INITIAL_F64 + MAX_DATAGRAM_SIZE_F64, + cwnd_initial_f64 + convert_to_f64(cubic.max_datagram_size()), f64::EPSILON, ); - assert_eq!(cubic.cwnd(), CWND_AFTER_LOSS_SLOW_START); + assert_eq!( + cubic.cwnd(), + cwnd_after_loss_slow_start(cubic.cwnd_initial(), cubic.max_datagram_size()) + ); } #[test] fn congestion_event_congestion_avoidance() { - let mut cubic = ClassicCongestionControl::new(Cubic::default()); + let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)); // Set ssthresh to something small to make sure that cc is in the congection avoidance phase. cubic.set_ssthresh(1); // Set last_max_cwnd to something smaller than cwnd so that the fast convergence is not // triggered. 
- cubic.set_last_max_cwnd(3.0 * MAX_DATAGRAM_SIZE_F64); + cubic.set_last_max_cwnd(3.0 * convert_to_f64(cubic.max_datagram_size())); _ = fill_cwnd(&mut cubic, 0, now()); ack_packet(&mut cubic, 0, now()); - assert_eq!(cubic.cwnd(), CWND_INITIAL); + assert_eq!(cubic.cwnd(), cubic.cwnd_initial()); // Trigger a congestion_event in slow start phase packet_lost(&mut cubic, 1); - assert_within(cubic.last_max_cwnd(), CWND_INITIAL_F64, f64::EPSILON); - assert_eq!(cubic.cwnd(), CWND_AFTER_LOSS); + let cwnd_initial_f64: f64 = convert_to_f64(cubic.cwnd_initial()); + assert_within(cubic.last_max_cwnd(), cwnd_initial_f64, f64::EPSILON); + assert_eq!(cubic.cwnd(), cwnd_after_loss(cubic.cwnd_initial())); } #[test] fn congestion_event_congestion_avoidance_2() { - let mut cubic = ClassicCongestionControl::new(Cubic::default()); + let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)); // Set ssthresh to something small to make sure that cc is in the congection avoidance phase. cubic.set_ssthresh(1); // Set last_max_cwnd to something higher than cwnd so that the fast convergence is triggered. - cubic.set_last_max_cwnd(CWND_INITIAL_10_F64); + let cwnd_initial_f64: f64 = convert_to_f64(cubic.cwnd_initial()); + cubic.set_last_max_cwnd(cwnd_initial_f64 * 10.0); _ = fill_cwnd(&mut cubic, 0, now()); ack_packet(&mut cubic, 0, now()); - assert_within(cubic.last_max_cwnd(), CWND_INITIAL_10_F64, f64::EPSILON); - assert_eq!(cubic.cwnd(), CWND_INITIAL); + assert_within(cubic.last_max_cwnd(), cwnd_initial_f64 * 10.0, f64::EPSILON); + assert_eq!(cubic.cwnd(), cubic.cwnd_initial()); // Trigger a congestion_event. 
packet_lost(&mut cubic, 1); assert_within( cubic.last_max_cwnd(), - CWND_INITIAL_F64 * CUBIC_FAST_CONVERGENCE, + cwnd_initial_f64 * CUBIC_FAST_CONVERGENCE, f64::EPSILON, ); - assert_eq!(cubic.cwnd(), CWND_AFTER_LOSS); + assert_eq!(cubic.cwnd(), cwnd_after_loss(cubic.cwnd_initial())); } #[test] fn congestion_event_congestion_avoidance_test_no_overflow() { const PTO: Duration = Duration::from_millis(120); - let mut cubic = ClassicCongestionControl::new(Cubic::default()); + let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)); // Set ssthresh to something small to make sure that cc is in the congection avoidance phase. cubic.set_ssthresh(1); // Set last_max_cwnd to something higher than cwnd so that the fast convergence is triggered. - cubic.set_last_max_cwnd(CWND_INITIAL_10_F64); + let cwnd_initial_f64: f64 = convert_to_f64(cubic.cwnd_initial()); + cubic.set_last_max_cwnd(cwnd_initial_f64 * 10.0); _ = fill_cwnd(&mut cubic, 0, now()); ack_packet(&mut cubic, 1, now()); - assert_within(cubic.last_max_cwnd(), CWND_INITIAL_10_F64, f64::EPSILON); - assert_eq!(cubic.cwnd(), CWND_INITIAL); + assert_within(cubic.last_max_cwnd(), cwnd_initial_f64 * 10.0, f64::EPSILON); + assert_eq!(cubic.cwnd(), cubic.cwnd_initial()); // Now ack packet that was send earlier. 
ack_packet(&mut cubic, 0, now().checked_sub(PTO).unwrap()); diff --git a/neqo-transport/src/cc/tests/new_reno.rs b/neqo-transport/src/cc/tests/new_reno.rs index a82e4995f4..1ee8c74f67 100644 --- a/neqo-transport/src/cc/tests/new_reno.rs +++ b/neqo-transport/src/cc/tests/new_reno.rs @@ -6,38 +6,40 @@ // Congestion control -use std::time::Duration; +use std::{ + net::{IpAddr, Ipv4Addr}, + time::Duration, +}; use neqo_common::IpTosEcn; use test_fixture::now; use crate::{ - cc::{ - new_reno::NewReno, ClassicCongestionControl, CongestionControl, CWND_INITIAL, - MAX_DATAGRAM_SIZE, - }, + cc::{new_reno::NewReno, ClassicCongestionControl, CongestionControl}, packet::PacketType, + pmtud::Pmtud, recovery::SentPacket, rtt::RttEstimate, }; +const IP_ADDR: IpAddr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); const PTO: Duration = Duration::from_millis(100); const RTT: Duration = Duration::from_millis(98); -const RTT_ESTIMATE: RttEstimate = RttEstimate::from_duration(Duration::from_millis(98)); +const RTT_ESTIMATE: RttEstimate = RttEstimate::from_duration(RTT); fn cwnd_is_default(cc: &ClassicCongestionControl) { - assert_eq!(cc.cwnd(), CWND_INITIAL); + assert_eq!(cc.cwnd(), cc.cwnd_initial()); assert_eq!(cc.ssthresh(), usize::MAX); } fn cwnd_is_halved(cc: &ClassicCongestionControl) { - assert_eq!(cc.cwnd(), CWND_INITIAL / 2); - assert_eq!(cc.ssthresh(), CWND_INITIAL / 2); + assert_eq!(cc.cwnd(), cc.cwnd_initial() / 2); + assert_eq!(cc.ssthresh(), cc.cwnd_initial() / 2); } #[test] fn issue_876() { - let mut cc = ClassicCongestionControl::new(NewReno::default()); + let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)); let time_now = now(); let time_before = time_now.checked_sub(Duration::from_millis(100)).unwrap(); let time_after = time_now + Duration::from_millis(150); @@ -50,7 +52,7 @@ fn issue_876() { time_before, true, Vec::new(), - MAX_DATAGRAM_SIZE - 1, + cc.max_datagram_size() - 1, ), SentPacket::new( PacketType::Short, @@ -59,7 +61,7 @@ fn 
issue_876() { time_before, true, Vec::new(), - MAX_DATAGRAM_SIZE - 2, + cc.max_datagram_size() - 2, ), SentPacket::new( PacketType::Short, @@ -68,7 +70,7 @@ fn issue_876() { time_before, true, Vec::new(), - MAX_DATAGRAM_SIZE, + cc.max_datagram_size(), ), SentPacket::new( PacketType::Short, @@ -77,7 +79,7 @@ fn issue_876() { time_before, true, Vec::new(), - MAX_DATAGRAM_SIZE, + cc.max_datagram_size(), ), SentPacket::new( PacketType::Short, @@ -86,7 +88,7 @@ fn issue_876() { time_before, true, Vec::new(), - MAX_DATAGRAM_SIZE, + cc.max_datagram_size(), ), SentPacket::new( PacketType::Short, @@ -95,7 +97,7 @@ fn issue_876() { time_before, true, Vec::new(), - MAX_DATAGRAM_SIZE, + cc.max_datagram_size(), ), SentPacket::new( PacketType::Short, @@ -104,7 +106,7 @@ fn issue_876() { time_after, true, Vec::new(), - MAX_DATAGRAM_SIZE - 3, + cc.max_datagram_size() - 3, ), ]; @@ -114,7 +116,7 @@ fn issue_876() { } assert_eq!(cc.acked_bytes(), 0); cwnd_is_default(&cc); - assert_eq!(cc.bytes_in_flight(), 6 * MAX_DATAGRAM_SIZE - 3); + assert_eq!(cc.bytes_in_flight(), 6 * cc.max_datagram_size() - 3); cc.on_packets_lost(Some(time_now), None, PTO, &sent_packets[0..1]); @@ -122,35 +124,36 @@ fn issue_876() { assert!(cc.recovery_packet()); assert_eq!(cc.acked_bytes(), 0); cwnd_is_halved(&cc); - assert_eq!(cc.bytes_in_flight(), 5 * MAX_DATAGRAM_SIZE - 2); + assert_eq!(cc.bytes_in_flight(), 5 * cc.max_datagram_size() - 2); // Send a packet after recovery starts cc.on_packet_sent(&sent_packets[6]); assert!(!cc.recovery_packet()); cwnd_is_halved(&cc); assert_eq!(cc.acked_bytes(), 0); - assert_eq!(cc.bytes_in_flight(), 6 * MAX_DATAGRAM_SIZE - 5); + assert_eq!(cc.bytes_in_flight(), 6 * cc.max_datagram_size() - 5); // and ack it. 
cwnd increases slightly cc.on_packets_acked(&sent_packets[6..], &RTT_ESTIMATE, time_now); assert_eq!(cc.acked_bytes(), sent_packets[6].len()); cwnd_is_halved(&cc); - assert_eq!(cc.bytes_in_flight(), 5 * MAX_DATAGRAM_SIZE - 2); + assert_eq!(cc.bytes_in_flight(), 5 * cc.max_datagram_size() - 2); // Packet from before is lost. Should not hurt cwnd. cc.on_packets_lost(Some(time_now), None, PTO, &sent_packets[1..2]); assert!(!cc.recovery_packet()); assert_eq!(cc.acked_bytes(), sent_packets[6].len()); cwnd_is_halved(&cc); - assert_eq!(cc.bytes_in_flight(), 4 * MAX_DATAGRAM_SIZE); + assert_eq!(cc.bytes_in_flight(), 4 * cc.max_datagram_size()); } #[test] // https://github.com/mozilla/neqo/pull/1465 fn issue_1465() { - let mut cc = ClassicCongestionControl::new(NewReno::default()); + let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)); let mut pn = 0; let mut now = now(); + let max_datagram_size = cc.max_datagram_size(); let mut next_packet = |now| { let p = SentPacket::new( PacketType::Short, @@ -159,7 +162,7 @@ fn issue_1465() { now, true, Vec::new(), - MAX_DATAGRAM_SIZE, + max_datagram_size, ); pn += 1; p @@ -176,7 +179,7 @@ fn issue_1465() { assert_eq!(cc.acked_bytes(), 0); cwnd_is_default(&cc); - assert_eq!(cc.bytes_in_flight(), 3 * MAX_DATAGRAM_SIZE); + assert_eq!(cc.bytes_in_flight(), 3 * cc.max_datagram_size()); // advance one rtt to detect lost packet there this simplifies the timers, because // on_packet_loss would only be called after RTO, but that is not relevant to the problem @@ -187,13 +190,13 @@ fn issue_1465() { assert!(cc.recovery_packet()); assert_eq!(cc.acked_bytes(), 0); cwnd_is_halved(&cc); - assert_eq!(cc.bytes_in_flight(), 2 * MAX_DATAGRAM_SIZE); + assert_eq!(cc.bytes_in_flight(), 2 * cc.max_datagram_size()); // Don't reduce the cwnd again on second packet loss cc.on_packets_lost(Some(now), None, PTO, &[p3]); assert_eq!(cc.acked_bytes(), 0); cwnd_is_halved(&cc); // still the same as after first packet loss - 
assert_eq!(cc.bytes_in_flight(), MAX_DATAGRAM_SIZE); + assert_eq!(cc.bytes_in_flight(), cc.max_datagram_size()); // the acked packets before on_packet_sent were the cause of // https://github.com/mozilla/neqo/pull/1465 @@ -219,7 +222,7 @@ fn issue_1465() { assert!(cc.recovery_packet()); assert_eq!(cc.cwnd(), cur_cwnd / 2); assert_eq!(cc.acked_bytes(), 0); - assert_eq!(cc.bytes_in_flight(), 2 * MAX_DATAGRAM_SIZE); + assert_eq!(cc.bytes_in_flight(), 2 * cc.max_datagram_size()); // this shouldn't introduce further cwnd reduction, but it did before https://github.com/mozilla/neqo/pull/1465 cc.on_packets_lost(Some(now), None, PTO, &[p6]); diff --git a/neqo-transport/src/connection/dump.rs b/neqo-transport/src/connection/dump.rs index 12d337c570..22d4ede474 100644 --- a/neqo-transport/src/connection/dump.rs +++ b/neqo-transport/src/connection/dump.rs @@ -18,7 +18,7 @@ use crate::{ path::PathRef, }; -#[allow(clippy::module_name_repetitions)] +#[allow(clippy::too_many_arguments)] pub fn dump_packet( conn: &Connection, path: &PathRef, @@ -27,6 +27,7 @@ pub fn dump_packet( pn: PacketNumber, payload: &[u8], tos: IpTos, + len: usize, ) { if log::STATIC_MAX_LEVEL == log::LevelFilter::Off || !log::log_enabled!(log::Level::Debug) { return; @@ -46,11 +47,12 @@ pub fn dump_packet( } qdebug!( [conn], - "pn={} type={:?} {} {:?}{}", + "pn={} type={:?} {} {:?} len {}{}", pn, pt, path.borrow(), tos, + len, s ); } diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index 2ff2f7cc26..8e862b58aa 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -1521,6 +1521,7 @@ impl Connection { payload.pn(), &payload[..], d.tos(), + d.len(), ); #[cfg(feature = "build-fuzzing-corpus")] @@ -2110,6 +2111,7 @@ impl Connection { space: PacketNumberSpace, profile: &SendProfile, builder: &mut PacketBuilder, + coalesced: bool, // Whether this packet is coalesced behind another one. 
now: Instant, ) -> (Vec, bool, bool) { let mut tokens = Vec::new(); @@ -2129,11 +2131,11 @@ impl Connection { } let ack_end = builder.len(); - // Avoid sending probes until the handshake completes, + // Avoid sending path validation probes until the handshake completes, // but send them even when we don't have space. - let full_mtu = profile.limit() == path.borrow().mtu(); + let full_mtu = profile.limit() == path.borrow().plpmtu(); if space == PacketNumberSpace::ApplicationData && self.state.connected() { - // Probes should only be padded if the full MTU is available. + // Path validation probes should only be padded if the full MTU is available. // The probing code needs to know so it can track that. if path.borrow_mut().write_frames( builder, @@ -2152,6 +2154,16 @@ impl Connection { if primary { if space == PacketNumberSpace::ApplicationData { + if self.state.connected() + && path.borrow().pmtud().needs_probe() + && !coalesced // Only send PMTUD probes using non-coalesced packets. + && full_mtu + { + path.borrow_mut() + .pmtud_mut() + .send_probe(builder, &mut self.stats.borrow_mut()); + ack_eliciting = true; + } self.write_appdata_frames(builder, &mut tokens); } else { let stats = &mut self.stats.borrow_mut().frame_tx; @@ -2231,7 +2243,6 @@ impl Connection { let version = self.version(); // Determine how we are sending packets (PTO, etc..). - let mtu = path.borrow().mtu(); let profile = self.loss_recovery.send_profile(&path.borrow(), now); qdebug!([self], "output_path send_profile {:?}", profile); @@ -2267,9 +2278,16 @@ impl Connection { // Configure the limits and padding for this packet. let aead_expansion = tx.expansion(); - builder.set_limit(profile.limit() - aead_expansion); + needs_padding |= builder.set_initial_limit( + &profile, + aead_expansion, + self.paths + .primary() + .ok_or(Error::InternalError)? 
+ .borrow() + .pmtud(), + ); builder.enable_padding(needs_padding); - debug_assert!(builder.limit() <= 2048); if builder.is_full() { encoder = builder.abort(); break; @@ -2282,7 +2300,7 @@ impl Connection { self.write_closing_frames(close, &mut builder, *space, now, path, &mut tokens); } else { (tokens, ack_eliciting, padded) = - self.write_frames(path, *space, &profile, &mut builder, now); + self.write_frames(path, *space, &profile, &mut builder, header_start != 0, now); } if builder.packet_empty() { // Nothing to include in this packet. @@ -2298,6 +2316,7 @@ impl Connection { pn, &builder.as_ref()[payload_start..], path.borrow().tos(), + builder.len() + aead_expansion, ); qlog::packet_sent( &self.qlog, @@ -2310,7 +2329,6 @@ impl Connection { self.stats.borrow_mut().packets_tx += 1; let tx = self.crypto.states.tx_mut(self.version, cspace).unwrap(); encoder = builder.build(tx)?; - debug_assert!(encoder.len() <= mtu); self.crypto.states.auto_update()?; if ack_eliciting { @@ -2360,14 +2378,14 @@ impl Connection { if needs_padding { qdebug!( [self], - "pad Initial from {} to path MTU {}", + "pad Initial from {} to PLPMTU {}", packets.len(), - mtu + profile.limit() ); - initial.track_padding(mtu - packets.len()); + initial.track_padding(profile.limit() - packets.len()); // These zeros aren't padding frames, they are an invalid all-zero coalesced // packet, which is why we don't increase `frame_tx.padding` count here. - packets.resize(mtu, 0); + packets.resize(profile.limit(), 0); } self.loss_recovery.on_packet_sent(path, initial); } @@ -2697,6 +2715,19 @@ impl Connection { Ok(()) } + fn set_confirmed(&mut self) -> Res<()> { + self.set_state(State::Confirmed); + if self.conn_params.pmtud_enabled() { + self.paths + .primary() + .ok_or(Error::InternalError)? + .borrow_mut() + .pmtud_mut() + .start(); + } + Ok(()) + } + #[allow(clippy::too_many_lines)] // Yep, but it's a nice big match, which is basically lots of little functions. 
fn input_frame( &mut self, @@ -2851,7 +2882,7 @@ impl Connection { if self.role == Role::Server || !self.state.connected() { return Err(Error::ProtocolViolation); } - self.set_state(State::Confirmed); + self.set_confirmed()?; self.discard_keys(PacketNumberSpace::Handshake, now); self.migrate_to_preferred_address(now)?; } @@ -3026,7 +3057,7 @@ impl Connection { self.stats.borrow_mut().resumed = self.crypto.tls.info().unwrap().resumed(); if self.role == Role::Server { self.state_signaling.handshake_done(); - self.set_state(State::Confirmed); + self.set_confirmed()?; } qinfo!([self], "Connection established"); Ok(()) @@ -3298,7 +3329,7 @@ impl Connection { return Err(Error::NotAvailable); }; let path = self.paths.primary().ok_or(Error::NotAvailable)?; - let mtu = path.borrow().mtu(); + let mtu = path.borrow().plpmtu(); let encoder = Encoder::with_capacity(mtu); let (_, mut builder) = Self::build_packet_header( @@ -3338,6 +3369,17 @@ impl Connection { self.quic_datagrams .add_datagram(buf, id.into(), &mut self.stats.borrow_mut()) } + + /// Return the PLMTU of the primary path. + /// + /// # Panics + /// + /// The function panics if there is no primary path. (Should be fine for test usage.) + #[cfg(test)] + #[must_use] + pub fn plpmtu(&self) -> usize { + self.paths.primary().unwrap().borrow().plpmtu() + } } impl EventProvider for Connection { diff --git a/neqo-transport/src/connection/params.rs b/neqo-transport/src/connection/params.rs index 3263e3b734..e305771ff4 100644 --- a/neqo-transport/src/connection/params.rs +++ b/neqo-transport/src/connection/params.rs @@ -79,6 +79,8 @@ pub struct ConnectionParameters { fast_pto: u8, grease: bool, pacing: bool, + /// Whether the connection performs PLPMTUD. 
+ pmtud: bool, } impl Default for ConnectionParameters { @@ -101,6 +103,7 @@ impl Default for ConnectionParameters { fast_pto: FAST_PTO_SCALE, grease: true, pacing: true, + pmtud: false, } } } @@ -344,6 +347,17 @@ impl ConnectionParameters { self } + #[must_use] + pub const fn pmtud_enabled(&self) -> bool { + self.pmtud + } + + #[must_use] + pub const fn pmtud(mut self, pmtud: bool) -> Self { + self.pmtud = pmtud; + self + } + /// # Errors /// When a connection ID cannot be obtained. /// # Panics diff --git a/neqo-transport/src/connection/tests/cc.rs b/neqo-transport/src/connection/tests/cc.rs index f21f4e184f..2f66774881 100644 --- a/neqo-transport/src/connection/tests/cc.rs +++ b/neqo-transport/src/connection/tests/cc.rs @@ -14,12 +14,13 @@ use super::{ CLIENT_HANDSHAKE_1RTT_PACKETS, DEFAULT_RTT, POST_HANDSHAKE_CWND, }; use crate::{ - cc::MAX_DATAGRAM_SIZE, + connection::tests::{connect_with_rtt, new_client, new_server, now}, packet::PacketNumber, recovery::{ACK_ONLY_SIZE_LIMIT, PACKET_THRESHOLD}, sender::PACING_BURST_SIZE, stream_id::StreamType, tracking::DEFAULT_ACK_PACKET_TOLERANCE, + ConnectionParameters, }; #[test] @@ -32,7 +33,22 @@ fn cc_slow_start() { // Try to send a lot of data let stream_id = client.stream_create(StreamType::UniDi).unwrap(); let (c_tx_dgrams, _) = fill_cwnd(&mut client, stream_id, now); - assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND); + assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND, client.plpmtu()); + assert!(cwnd_avail(&client) < ACK_ONLY_SIZE_LIMIT); +} + +#[test] +fn cc_slow_start_pmtud() { + let mut client = new_client(ConnectionParameters::default().pmtud(true)); + let mut server = new_server(ConnectionParameters::default().pmtud(true)); + let now = connect_with_rtt(&mut client, &mut server, now(), DEFAULT_RTT); + + // Try to send a lot of data + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + let cwnd = cwnd_avail(&client); + let (dgrams, _) = fill_cwnd(&mut client, stream_id, now); + let 
dgrams_len = dgrams.iter().map(|d| d.len()).sum::(); + assert_eq!(dgrams_len, cwnd); assert!(cwnd_avail(&client) < ACK_ONLY_SIZE_LIMIT); } @@ -53,7 +69,7 @@ fn cc_slow_start_to_cong_avoidance_recovery_period(congestion_signal: Congestion // Buffer up lot of data and generate packets let (c_tx_dgrams, mut now) = fill_cwnd(&mut client, stream_id, now); - assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND); + assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND, client.plpmtu()); // Predict the packet number of the last packet sent. // We have already sent packets in `connect_rtt_idle`, // so include a fudge factor. @@ -79,7 +95,7 @@ fn cc_slow_start_to_cong_avoidance_recovery_period(congestion_signal: Congestion // Client: send more let (mut c_tx_dgrams, mut now) = fill_cwnd(&mut client, stream_id, now); - assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND * 2); + assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND * 2, client.plpmtu()); let flight2_largest = flight1_largest + u64::try_from(c_tx_dgrams.len()).unwrap(); // Server: Receive and generate ack again, but this time add congestion @@ -135,7 +151,7 @@ fn cc_cong_avoidance_recovery_period_unchanged() { // Buffer up lot of data and generate packets let (mut c_tx_dgrams, now) = fill_cwnd(&mut client, stream_id, now); - assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND); + assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND, client.plpmtu()); // Drop 0th packet. When acked, this should put client into CARP. 
c_tx_dgrams.remove(0); @@ -267,7 +283,7 @@ fn cc_cong_avoidance_recovery_period_to_cong_avoidance() { now = next_now; next_c_tx_dgrams.append(&mut new_pkts); - expected_cwnd += MAX_DATAGRAM_SIZE; + expected_cwnd += client.plpmtu(); assert_eq!(cwnd(&client), expected_cwnd); c_tx_dgrams = next_c_tx_dgrams; } @@ -284,7 +300,7 @@ fn cc_slow_start_to_persistent_congestion_no_acks() { // Buffer up lot of data and generate packets let (c_tx_dgrams, mut now) = fill_cwnd(&mut client, stream, now); - assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND); + assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND, client.plpmtu()); // Server: Receive and generate ack now += DEFAULT_RTT / 2; @@ -306,7 +322,7 @@ fn cc_slow_start_to_persistent_congestion_some_acks() { // Buffer up lot of data and generate packets let (c_tx_dgrams, mut now) = fill_cwnd(&mut client, stream, now); - assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND); + assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND, client.plpmtu()); // Server: Receive and generate ack now += Duration::from_millis(100); @@ -335,7 +351,7 @@ fn cc_persistent_congestion_to_slow_start() { // Buffer up lot of data and generate packets let (c_tx_dgrams, mut now) = fill_cwnd(&mut client, stream, now); - assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND); + assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND, client.plpmtu()); // Server: Receive and generate ack now += Duration::from_millis(10); @@ -378,7 +394,7 @@ fn ack_are_not_cc() { // Buffer up lot of data and generate packets, so that cc window is filled. let (c_tx_dgrams, now) = fill_cwnd(&mut client, stream, now); - assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND); + assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND, client.plpmtu()); // The server hasn't received any of these packets yet, the server // won't ACK, but if it sends an ack-eliciting packet instead. 
@@ -431,7 +447,7 @@ fn pace() { } let gap = client.process_output(now).callback(); assert_ne!(gap, Duration::new(0, 0)); - for _ in (1 + PACING_BURST_SIZE)..cwnd_packets(POST_HANDSHAKE_CWND) { + for _ in (1 + PACING_BURST_SIZE)..cwnd_packets(POST_HANDSHAKE_CWND, client.plpmtu()) { match client.process_output(now) { Output::Callback(t) => assert_eq!(t, gap), Output::Datagram(_) => { @@ -448,7 +464,7 @@ fn pace() { } let dgram = client.process_output(now).dgram(); assert!(dgram.is_none()); - assert_eq!(count, cwnd_packets(POST_HANDSHAKE_CWND)); + assert_eq!(count, cwnd_packets(POST_HANDSHAKE_CWND, client.plpmtu())); let fin = client.process_output(now).callback(); assert_ne!(fin, Duration::new(0, 0)); assert_ne!(fin, gap); diff --git a/neqo-transport/src/connection/tests/datagram.rs b/neqo-transport/src/connection/tests/datagram.rs index 864f6e05b6..2688fc0cc8 100644 --- a/neqo-transport/src/connection/tests/datagram.rs +++ b/neqo-transport/src/connection/tests/datagram.rs @@ -7,6 +7,7 @@ use std::{cell::RefCell, rc::Rc}; use neqo_common::event::Provider; +use static_assertions::const_assert; use test_fixture::now; use super::{ @@ -14,20 +15,27 @@ use super::{ AT_LEAST_PTO, }; use crate::{ + connection::tests::DEFAULT_ADDR, events::{ConnectionEvent, OutgoingDatagramOutcome}, frame::FRAME_TYPE_DATAGRAM, packet::PacketBuilder, quic_datagrams::MAX_QUIC_DATAGRAM, send_stream::{RetransmissionPriority, TransmissionPriority}, - CloseReason, Connection, ConnectionParameters, Error, StreamType, MIN_INITIAL_PACKET_SIZE, + CloseReason, Connection, ConnectionParameters, Error, Pmtud, StreamType, + MIN_INITIAL_PACKET_SIZE, }; -const DATAGRAM_LEN_MTU: u64 = 1310; -const DATA_MTU: &[u8] = &[1; 1310]; -const DATA_BIGGER_THAN_MTU: &[u8] = &[0; 2620]; +// FIXME: The 27 here is a magic constant that the original code also (implicitly) had. 
+const DATAGRAM_LEN_MTU: usize = Pmtud::default_plpmtu(DEFAULT_ADDR.ip()) - 27; +const DATA_MTU: &[u8] = &[1; DATAGRAM_LEN_MTU]; +const DATA_BIGGER_THAN_MTU: &[u8] = &[0; 2 * DATAGRAM_LEN_MTU]; +const_assert!(DATA_BIGGER_THAN_MTU.len() > DATAGRAM_LEN_MTU); const DATAGRAM_LEN_SMALLER_THAN_MTU: u64 = MIN_INITIAL_PACKET_SIZE as u64; +const_assert!(DATAGRAM_LEN_SMALLER_THAN_MTU < DATAGRAM_LEN_MTU as u64); const DATA_SMALLER_THAN_MTU: &[u8] = &[0; MIN_INITIAL_PACKET_SIZE]; -const DATA_SMALLER_THAN_MTU_2: &[u8] = &[0; 600]; +const_assert!(DATA_SMALLER_THAN_MTU.len() < DATAGRAM_LEN_MTU); +const DATA_SMALLER_THAN_MTU_2: &[u8] = &[0; MIN_INITIAL_PACKET_SIZE / 2]; +const_assert!(DATA_SMALLER_THAN_MTU_2.len() < DATA_SMALLER_THAN_MTU.len()); const OUTGOING_QUEUE: usize = 2; struct InsertDatagram<'a> { @@ -132,15 +140,20 @@ fn connect_datagram() -> (Connection, Connection) { fn mtu_limit() { let (client, server) = connect_datagram(); - assert_eq!(client.max_datagram_size(), Ok(DATAGRAM_LEN_MTU)); - assert_eq!(server.max_datagram_size(), Ok(DATAGRAM_LEN_MTU)); + assert_eq!( + client.max_datagram_size(), + Ok((DATAGRAM_LEN_MTU).try_into().unwrap()) + ); + assert_eq!( + server.max_datagram_size(), + Ok((DATAGRAM_LEN_MTU).try_into().unwrap()) + ); } #[test] fn limit_data_size() { let (mut client, mut server) = connect_datagram(); - assert!(u64::try_from(DATA_BIGGER_THAN_MTU.len()).unwrap() > DATAGRAM_LEN_MTU); // Datagram can be queued because they are smaller than allowed by the peer, // but they cannot be sent. assert_eq!(server.send_datagram(DATA_BIGGER_THAN_MTU, Some(1)), Ok(())); @@ -173,7 +186,6 @@ fn limit_data_size() { fn after_dgram_dropped_continue_writing_frames() { let (mut client, _) = connect_datagram(); - assert!(u64::try_from(DATA_BIGGER_THAN_MTU.len()).unwrap() > DATAGRAM_LEN_MTU); // Datagram can be queued because they are smaller than allowed by the peer, // but they cannot be sent. 
assert_eq!(client.send_datagram(DATA_BIGGER_THAN_MTU, Some(1)), Ok(())); @@ -377,7 +389,6 @@ fn dgram_too_big() { let mut server = default_server(); connect_force_idle(&mut client, &mut server); - assert!(DATAGRAM_LEN_MTU > DATAGRAM_LEN_SMALLER_THAN_MTU); server.test_frame_writer = Some(Box::new(InsertDatagram { data: DATA_MTU })); let out = server.process_output(now()).dgram().unwrap(); server.test_frame_writer = None; @@ -589,7 +600,7 @@ fn datagram_fill() { let path = p.borrow(); // Minimum overhead is connection ID length, 1 byte short header, 1 byte packet number, // 1 byte for the DATAGRAM frame type, and 16 bytes for the AEAD. - path.mtu() - path.remote_cid().len() - 19 + path.plpmtu() - path.remote_cid().len() - 19 }; assert!(space >= 64); // Unlikely, but this test depends on the datagram being this large. diff --git a/neqo-transport/src/connection/tests/handshake.rs b/neqo-transport/src/connection/tests/handshake.rs index 0165fa70eb..fce6e0bf97 100644 --- a/neqo-transport/src/connection/tests/handshake.rs +++ b/neqo-transport/src/connection/tests/handshake.rs @@ -29,31 +29,33 @@ use super::{ CountingConnectionIdGenerator, AT_LEAST_PTO, DEFAULT_RTT, DEFAULT_STREAM_DATA, }; use crate::{ - connection::AddressValidation, + connection::{ + tests::{new_client, new_server}, + AddressValidation, + }, events::ConnectionEvent, - path::PATH_MTU_V6, server::ValidateAddress, tparams::{TransportParameter, MIN_ACK_DELAY}, tracking::DEFAULT_ACK_DELAY, - CloseReason, ConnectionParameters, EmptyConnectionIdGenerator, Error, StreamType, Version, + CloseReason, ConnectionParameters, EmptyConnectionIdGenerator, Error, Pmtud, StreamType, + Version, }; const ECH_CONFIG_ID: u8 = 7; const ECH_PUBLIC_NAME: &str = "public.example"; -#[test] -fn full_handshake() { +fn full_handshake(pmtud: bool) { qdebug!("---- client: generate CH"); - let mut client = default_client(); + let mut client = new_client(ConnectionParameters::default().pmtud(pmtud)); let out = client.process(None, 
now()); assert!(out.as_dgram_ref().is_some()); - assert_eq!(out.as_dgram_ref().unwrap().len(), PATH_MTU_V6); + assert_eq!(out.as_dgram_ref().unwrap().len(), client.plpmtu()); qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN"); - let mut server = default_server(); + let mut server = new_server(ConnectionParameters::default().pmtud(pmtud)); let out = server.process(out.as_dgram_ref(), now()); assert!(out.as_dgram_ref().is_some()); - assert_eq!(out.as_dgram_ref().unwrap().len(), PATH_MTU_V6); + assert_eq!(out.as_dgram_ref().unwrap().len(), server.plpmtu()); qdebug!("---- client: cert verification"); let out = client.process(out.as_dgram_ref(), now()); @@ -76,10 +78,26 @@ fn full_handshake() { qdebug!("---- client: ACKS -> 0"); let out = client.process(out.as_dgram_ref(), now()); - assert!(out.as_dgram_ref().is_none()); + if pmtud { + // PMTUD causes a PING probe to be sent here + let pkt = out.dgram().unwrap(); + assert!(pkt.len() > client.plpmtu()); + } else { + assert!(out.as_dgram_ref().is_none()); + } assert_eq!(*client.state(), State::Confirmed); } +#[test] +fn handshake_no_pmtud() { + full_handshake(false); +} + +#[test] +fn handshake_pmtud() { + full_handshake(true); +} + #[test] fn handshake_failed_authentication() { qdebug!("---- client: generate CH"); @@ -143,7 +161,7 @@ fn dup_server_flight1() { let mut client = default_client(); let out = client.process(None, now()); assert!(out.as_dgram_ref().is_some()); - assert_eq!(out.as_dgram_ref().unwrap().len(), PATH_MTU_V6); + assert_eq!(out.as_dgram_ref().unwrap().len(), client.plpmtu()); qdebug!("Output={:0x?}", out.as_dgram_ref()); qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN"); @@ -267,7 +285,7 @@ fn send_05rtt() { let c1 = client.process(None, now()).dgram(); assert!(c1.is_some()); let s1 = server.process(c1.as_ref(), now()).dgram().unwrap(); - assert_eq!(s1.len(), PATH_MTU_V6); + assert_eq!(s1.len(), server.plpmtu()); // The server should accept writes at this point. 
let s2 = send_something(&mut server, now()); @@ -437,7 +455,7 @@ fn coalesce_05rtt() { let s2 = server.process(c2.as_ref(), now).dgram(); // Even though there is a 1-RTT packet at the end of the datagram, the // flight should be padded to full size. - assert_eq!(s2.as_ref().unwrap().len(), PATH_MTU_V6); + assert_eq!(s2.as_ref().unwrap().len(), server.plpmtu()); // The client should process the datagram. It can't process the 1-RTT // packet until authentication completes though. So it saves it. @@ -645,7 +663,7 @@ fn verify_pkt_honors_mtu() { assert_eq!(client.stream_send(stream_id, &[0xbb; 2000]).unwrap(), 2000); let pkt0 = client.process(None, now); assert!(matches!(pkt0, Output::Datagram(_))); - assert_eq!(pkt0.as_dgram_ref().unwrap().len(), PATH_MTU_V6); + assert_eq!(pkt0.as_dgram_ref().unwrap().len(), client.plpmtu()); } #[test] @@ -759,15 +777,15 @@ fn anti_amplification() { // With a gigantic transport parameter, the server is unable to complete // the handshake within the amplification limit. - let very_big = TransportParameter::Bytes(vec![0; PATH_MTU_V6 * 3]); + let very_big = TransportParameter::Bytes(vec![0; Pmtud::default_plpmtu(DEFAULT_ADDR.ip()) * 3]); server.set_local_tparam(0xce16, very_big).unwrap(); let c_init = client.process_output(now).dgram(); now += DEFAULT_RTT / 2; let s_init1 = server.process(c_init.as_ref(), now).dgram().unwrap(); - assert_eq!(s_init1.len(), PATH_MTU_V6); + assert_eq!(s_init1.len(), client.plpmtu()); let s_init2 = server.process_output(now).dgram().unwrap(); - assert_eq!(s_init2.len(), PATH_MTU_V6); + assert_eq!(s_init2.len(), server.plpmtu()); // Skip the gap for pacing here. 
let s_pacing = server.process_output(now).callback(); @@ -775,7 +793,7 @@ fn anti_amplification() { now += s_pacing; let s_init3 = server.process_output(now).dgram().unwrap(); - assert_eq!(s_init3.len(), PATH_MTU_V6); + assert_eq!(s_init3.len(), server.plpmtu()); let cb = server.process_output(now).callback(); assert_ne!(cb, Duration::new(0, 0)); @@ -790,7 +808,7 @@ fn anti_amplification() { // The client sends a padded datagram, with just ACK for Handshake. assert_eq!(client.stats().frame_tx.ack, ack_count + 1); assert_eq!(client.stats().frame_tx.all, frame_count + 1); - assert_ne!(ack.len(), PATH_MTU_V6); // Not padded (it includes Handshake). + assert_ne!(ack.len(), client.plpmtu()); // Not padded (it includes Handshake). now += DEFAULT_RTT / 2; let remainder = server.process(Some(&ack), now).dgram(); diff --git a/neqo-transport/src/connection/tests/keys.rs b/neqo-transport/src/connection/tests/keys.rs index c2ae9529bf..ca35b8b774 100644 --- a/neqo-transport/src/connection/tests/keys.rs +++ b/neqo-transport/src/connection/tests/keys.rs @@ -20,7 +20,6 @@ use super::{ use crate::{ crypto::{OVERWRITE_INVOCATIONS, UPDATE_WRITE_KEYS_AT}, packet::PacketNumber, - path::PATH_MTU_V6, }; fn check_discarded( @@ -60,7 +59,7 @@ fn discarded_initial_keys() { let mut client = default_client(); let init_pkt_c = client.process(None, now()).dgram(); assert!(init_pkt_c.is_some()); - assert_eq!(init_pkt_c.as_ref().unwrap().len(), PATH_MTU_V6); + assert_eq!(init_pkt_c.as_ref().unwrap().len(), client.plpmtu()); qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN"); let mut server = default_server(); diff --git a/neqo-transport/src/connection/tests/migration.rs b/neqo-transport/src/connection/tests/migration.rs index 17f3e549fa..3ee88943dd 100644 --- a/neqo-transport/src/connection/tests/migration.rs +++ b/neqo-transport/src/connection/tests/migration.rs @@ -28,7 +28,7 @@ use crate::{ connection::tests::send_something_paced, frame::FRAME_TYPE_NEW_CONNECTION_ID, packet::PacketBuilder, - 
path::{PATH_MTU_V4, PATH_MTU_V6}, + pmtud::Pmtud, tparams::{self, PreferredAddress, TransportParameter}, CloseReason, ConnectionId, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionIdRef, ConnectionParameters, EmptyConnectionIdGenerator, Error, @@ -464,10 +464,7 @@ fn fast_handshake(client: &mut Connection, server: &mut Connection) -> Option PATH_MTU_V4, - IpAddr::V6(_) => PATH_MTU_V6, - }; + let mtu = Pmtud::default_plpmtu(hs_client.ip()); let assert_orig_path = |d: &Datagram, full_mtu: bool| { assert_eq!( d.destination(), diff --git a/neqo-transport/src/connection/tests/mod.rs b/neqo-transport/src/connection/tests/mod.rs index 2fb70881b1..e156d8c0c1 100644 --- a/neqo-transport/src/connection/tests/mod.rs +++ b/neqo-transport/src/connection/tests/mod.rs @@ -20,12 +20,12 @@ use test_fixture::{fixture_init, new_neqo_qlog, now, DEFAULT_ADDR}; use super::{CloseReason, Connection, ConnectionId, Output, State}; use crate::{ addr_valid::{AddressValidation, ValidateAddress}, - cc::{CWND_INITIAL_PKTS, CWND_MIN}, + cc::CWND_INITIAL_PKTS, cid::ConnectionIdRef, events::ConnectionEvent, frame::FRAME_TYPE_PING, packet::PacketBuilder, - path::PATH_MTU_V6, + pmtud::Pmtud, recovery::ACK_ONLY_SIZE_LIMIT, stats::{FrameStats, Stats, MAX_PTO_COUNTS}, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionParameters, Error, StreamId, StreamType, @@ -375,24 +375,23 @@ fn fill_stream(c: &mut Connection, stream: StreamId) { /// pacing, this looks at the congestion window to tell when to stop. /// Returns a list of datagrams and the new time. fn fill_cwnd(c: &mut Connection, stream: StreamId, mut now: Instant) -> (Vec, Instant) { - // Train wreck function to get the remaining congestion window on the primary path. 
- fn cwnd(c: &Connection) -> usize { - c.paths.primary().unwrap().borrow().sender().cwnd_avail() - } - - qtrace!("fill_cwnd starting cwnd: {}", cwnd(c)); + qtrace!("fill_cwnd starting cwnd: {}", cwnd_avail(c)); fill_stream(c, stream); let mut total_dgrams = Vec::new(); loop { let pkt = c.process_output(now); - qtrace!("fill_cwnd cwnd remaining={}, output: {:?}", cwnd(c), pkt); + qtrace!( + "fill_cwnd cwnd remaining={}, output: {:?}", + cwnd_avail(c), + pkt + ); match pkt { Output::Datagram(dgram) => { total_dgrams.push(dgram); } Output::Callback(t) => { - if cwnd(c) < ACK_ONLY_SIZE_LIMIT { + if cwnd_avail(c) < ACK_ONLY_SIZE_LIMIT { break; } now += t; @@ -477,10 +476,15 @@ where fn cwnd(c: &Connection) -> usize { c.paths.primary().unwrap().borrow().sender().cwnd() } + fn cwnd_avail(c: &Connection) -> usize { c.paths.primary().unwrap().borrow().sender().cwnd_avail() } +fn cwnd_min(c: &Connection) -> usize { + c.paths.primary().unwrap().borrow().sender().cwnd_min() +} + fn induce_persistent_congestion( client: &mut Connection, server: &mut Connection, @@ -526,7 +530,7 @@ fn induce_persistent_congestion( // An ACK for the third PTO causes persistent congestion. let s_ack = ack_bytes(server, stream, c_tx_dgrams, now); client.process_input(&s_ack, now); - assert_eq!(cwnd(client), CWND_MIN); + assert_eq!(cwnd(client), cwnd_min(client)); now } @@ -539,30 +543,30 @@ fn induce_persistent_congestion( /// value could fail as a result of variations, so it's OK to just /// change this value, but it is good to first understand where the /// change came from. -const POST_HANDSHAKE_CWND: usize = PATH_MTU_V6 * CWND_INITIAL_PKTS; +const POST_HANDSHAKE_CWND: usize = Pmtud::default_plpmtu(DEFAULT_ADDR.ip()) * CWND_INITIAL_PKTS; /// Determine the number of packets required to fill the CWND. -const fn cwnd_packets(data: usize) -> usize { +const fn cwnd_packets(data: usize, mtu: usize) -> usize { // Add one if the last chunk is >= ACK_ONLY_SIZE_LIMIT. 
- (data + PATH_MTU_V6 - ACK_ONLY_SIZE_LIMIT) / PATH_MTU_V6 + (data + mtu - ACK_ONLY_SIZE_LIMIT) / mtu } /// Determine the size of the last packet. /// The minimal size of a packet is `ACK_ONLY_SIZE_LIMIT`. -const fn last_packet(cwnd: usize) -> usize { - if (cwnd % PATH_MTU_V6) > ACK_ONLY_SIZE_LIMIT { - cwnd % PATH_MTU_V6 +const fn last_packet(cwnd: usize, mtu: usize) -> usize { + if (cwnd % mtu) > ACK_ONLY_SIZE_LIMIT { + cwnd % mtu } else { - PATH_MTU_V6 + mtu } } /// Assert that the set of packets fill the CWND. -fn assert_full_cwnd(packets: &[Datagram], cwnd: usize) { - assert_eq!(packets.len(), cwnd_packets(cwnd)); +fn assert_full_cwnd(packets: &[Datagram], cwnd: usize, mtu: usize) { + assert_eq!(packets.len(), cwnd_packets(cwnd, mtu)); let (last, rest) = packets.split_last().unwrap(); - assert!(rest.iter().all(|d| d.len() == PATH_MTU_V6)); - assert_eq!(last.len(), last_packet(cwnd)); + assert!(rest.iter().all(|d| d.len() == mtu)); + assert_eq!(last.len(), last_packet(cwnd, mtu)); } /// Send something on a stream from `sender` to `receiver`, maybe allowing for pacing. diff --git a/neqo-transport/src/connection/tests/priority.rs b/neqo-transport/src/connection/tests/priority.rs index 079ba93b9f..26ba55260d 100644 --- a/neqo-transport/src/connection/tests/priority.rs +++ b/neqo-transport/src/connection/tests/priority.rs @@ -386,7 +386,7 @@ fn low() { // Send a session ticket and make it big enough to require a whole packet. // The resulting CRYPTO frame beats out the stream data. 
let stats_before = server.stats().frame_tx; - server.send_ticket(now, &[0; 2048]).unwrap(); + server.send_ticket(now, &vec![0; server.plpmtu()]).unwrap(); mem::drop(server.process_output(now)); let stats_after = server.stats().frame_tx; assert_eq!(stats_after.crypto, stats_before.crypto + 1); diff --git a/neqo-transport/src/connection/tests/recovery.rs b/neqo-transport/src/connection/tests/recovery.rs index 0f12d03107..a94a3c0b5d 100644 --- a/neqo-transport/src/connection/tests/recovery.rs +++ b/neqo-transport/src/connection/tests/recovery.rs @@ -20,11 +20,11 @@ use super::{ super::{Connection, ConnectionParameters, Output, State}, assert_full_cwnd, connect, connect_force_idle, connect_rtt_idle, connect_with_rtt, cwnd, default_client, default_server, fill_cwnd, maybe_authenticate, new_client, send_and_receive, - send_something, AT_LEAST_PTO, DEFAULT_RTT, DEFAULT_STREAM_DATA, POST_HANDSHAKE_CWND, + send_something, AT_LEAST_PTO, DEFAULT_ADDR, DEFAULT_RTT, DEFAULT_STREAM_DATA, + POST_HANDSHAKE_CWND, }; use crate::{ - cc::CWND_MIN, - path::PATH_MTU_V6, + connection::tests::cwnd_min, recovery::{ FAST_PTO_SCALE, MAX_OUTSTANDING_UNACK, MAX_PTO_PACKET_COUNT, MIN_OUTSTANDING_UNACK, }, @@ -32,7 +32,7 @@ use crate::{ stats::MAX_PTO_COUNTS, tparams::TransportParameter, tracking::DEFAULT_ACK_DELAY, - StreamType, + Pmtud, StreamType, }; #[test] @@ -82,14 +82,14 @@ fn pto_works_full_cwnd() { // Send lots of data. let stream_id = client.stream_create(StreamType::UniDi).unwrap(); let (dgrams, now) = fill_cwnd(&mut client, stream_id, now); - assert_full_cwnd(&dgrams, POST_HANDSHAKE_CWND); + assert_full_cwnd(&dgrams, POST_HANDSHAKE_CWND, client.plpmtu()); // Fill the CWND after waiting for a PTO. let (dgrams, now) = fill_cwnd(&mut client, stream_id, now + AT_LEAST_PTO); // Two packets in the PTO. // The first should be full sized; the second might be small. 
assert_eq!(dgrams.len(), 2); - assert_eq!(dgrams[0].len(), PATH_MTU_V6); + assert_eq!(dgrams[0].len(), client.plpmtu()); // Both datagrams contain one or more STREAM frames. for d in dgrams { @@ -168,7 +168,7 @@ fn pto_initial() { let mut client = default_client(); let pkt1 = client.process(None, now).dgram(); assert!(pkt1.is_some()); - assert_eq!(pkt1.clone().unwrap().len(), PATH_MTU_V6); + assert_eq!(pkt1.clone().unwrap().len(), client.plpmtu()); let delay = client.process(None, now).callback(); assert_eq!(delay, INITIAL_PTO); @@ -177,7 +177,7 @@ fn pto_initial() { now += delay; let pkt2 = client.process(None, now).dgram(); assert!(pkt2.is_some()); - assert_eq!(pkt2.unwrap().len(), PATH_MTU_V6); + assert_eq!(pkt2.unwrap().len(), client.plpmtu()); let delay = client.process(None, now).callback(); // PTO has doubled. @@ -382,7 +382,7 @@ fn handshake_ack_pto() { let mut server = default_server(); // This is a greasing transport parameter, and large enough that the // server needs to send two Handshake packets. - let big = TransportParameter::Bytes(vec![0; PATH_MTU_V6]); + let big = TransportParameter::Bytes(vec![0; Pmtud::default_plpmtu(DEFAULT_ADDR.ip())]); server.set_local_tparam(0xce16, big).unwrap(); let c1 = client.process(None, now).dgram(); @@ -800,5 +800,5 @@ fn fast_pto_persistent_congestion() { let ack = server.process(Some(&dgram), now).dgram(); now += DEFAULT_RTT / 2; client.process_input(&ack.unwrap(), now); - assert_eq!(cwnd(&client), CWND_MIN); + assert_eq!(cwnd(&client), cwnd_min(&client)); } diff --git a/neqo-transport/src/crypto.rs b/neqo-transport/src/crypto.rs index a9a76ec916..ea44f2e834 100644 --- a/neqo-transport/src/crypto.rs +++ b/neqo-transport/src/crypto.rs @@ -430,8 +430,12 @@ pub struct CryptoDxState { /// The total number of operations that are remaining before the keys /// become exhausted and can't be used any more. invocations: PacketNumber, + /// The basis of the invocation limits in `invocations`. 
+ largest_packet_len: usize, } +const INITIAL_LARGEST_PACKET_LEN: usize = 1 << 11; // 2048 + impl CryptoDxState { #[allow(clippy::reversed_empty_ranges)] // To initialize an empty range. pub fn new( @@ -458,6 +462,7 @@ impl CryptoDxState { used_pn: 0..0, min_pn: 0, invocations: Self::limit(direction, cipher), + largest_packet_len: INITIAL_LARGEST_PACKET_LEN, } } @@ -551,6 +556,7 @@ impl CryptoDxState { used_pn: pn..pn, min_pn: pn, invocations, + largest_packet_len: INITIAL_LARGEST_PACKET_LEN, } } @@ -645,10 +651,15 @@ impl CryptoDxState { hex(hdr), hex(body) ); - // The numbers in `Self::limit` assume a maximum packet size of 2^11. - if body.len() > 2048 { - debug_assert!(false); - return Err(Error::InternalError); + + // The numbers in `Self::limit` assume a maximum packet size of `LIMIT`. + // Adjust them as we encounter larger packets. + debug_assert!(body.len() < 65536); + if body.len() > self.largest_packet_len { + let new_bits = usize::leading_zeros(self.largest_packet_len - 1) + - usize::leading_zeros(body.len() - 1); + self.invocations >>= new_bits; + self.largest_packet_len = body.len(); } self.invoked()?; @@ -1295,6 +1306,7 @@ impl CryptoStates { used_pn: 0..645_971_972, min_pn: 0, invocations: 10, + largest_packet_len: INITIAL_LARGEST_PACKET_LEN, }, cipher: TLS_CHACHA20_POLY1305_SHA256, next_secret: secret.clone(), diff --git a/neqo-transport/src/lib.rs b/neqo-transport/src/lib.rs index c16030d694..541f851155 100644 --- a/neqo-transport/src/lib.rs +++ b/neqo-transport/src/lib.rs @@ -28,6 +28,7 @@ pub mod packet; #[cfg(not(fuzzing))] mod packet; mod path; +mod pmtud; mod qlog; mod quic_datagrams; mod recovery; @@ -62,6 +63,7 @@ pub use self::{ events::{ConnectionEvent, ConnectionEvents}, frame::CloseError, packet::MIN_INITIAL_PACKET_SIZE, + pmtud::Pmtud, quic_datagrams::DatagramTracking, recv_stream::{RecvStreamStats, RECV_BUFFER_SIZE}, send_stream::{SendStreamStats, SEND_BUFFER_SIZE}, diff --git a/neqo-transport/src/pace.rs b/neqo-transport/src/pace.rs 
index 5b88e5c0c4..d34d015ab1 100644 --- a/neqo-transport/src/pace.rs +++ b/neqo-transport/src/pace.rs @@ -60,6 +60,14 @@ impl Pacer { } } + pub const fn mtu(&self) -> usize { + self.p + } + + pub fn set_mtu(&mut self, mtu: usize) { + self.p = mtu; + } + /// Determine when the next packet will be available based on the provided RTT /// and congestion window. This doesn't update state. /// This returns a time, which could be in the past (this object doesn't know what diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index dd1bc225c5..6ac257fabf 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -19,8 +19,9 @@ use crate::{ cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdRef, MAX_CONNECTION_ID_LEN}, crypto::{CryptoDxState, CryptoSpace, CryptoStates}, frame::FRAME_TYPE_PADDING, + recovery::SendProfile, version::{Version, WireVersion}, - Error, Res, + Error, Pmtud, Res, }; /// `MIN_INITIAL_PACKET_SIZE` is the smallest packet that can be used to establish @@ -229,6 +230,24 @@ impl PacketBuilder { self.limit = limit; } + /// Set the initial limit for the packet, based on the profile and the PMTUD state. + /// Returns true if the packet needs padding. + pub fn set_initial_limit( + &mut self, + profile: &SendProfile, + aead_expansion: usize, + pmtud: &Pmtud, + ) -> bool { + if pmtud.needs_probe() { + debug_assert!(pmtud.probe_size() > profile.limit()); + self.limit = pmtud.probe_size() - aead_expansion; + true + } else { + self.limit = profile.limit() - aead_expansion; + false + } + } + /// Get the current limit. 
#[must_use] pub const fn limit(&self) -> usize { diff --git a/neqo-transport/src/path.rs b/neqo-transport/src/path.rs index cee105a5ed..de4037b495 100644 --- a/neqo-transport/src/path.rs +++ b/neqo-transport/src/path.rs @@ -10,7 +10,7 @@ use std::{ cell::RefCell, fmt::{self, Display}, mem, - net::{IpAddr, SocketAddr}, + net::SocketAddr, rc::Rc, time::{Duration, Instant}, }; @@ -25,6 +25,7 @@ use crate::{ ecn::{EcnCount, EcnInfo}, frame::{FRAME_TYPE_PATH_CHALLENGE, FRAME_TYPE_PATH_RESPONSE, FRAME_TYPE_RETIRE_CONNECTION_ID}, packet::PacketBuilder, + pmtud::Pmtud, recovery::{RecoveryToken, SentPacket}, rtt::RttEstimate, sender::PacketSender, @@ -33,14 +34,6 @@ use crate::{ Stats, }; -/// This is the MTU that we assume when using IPv6. -/// We use this size for Initial packets, so we don't need to worry about probing for support. -/// If the path doesn't support this MTU, we will assume that it doesn't support QUIC. -/// -/// This is a multiple of 16 greater than the largest possible short header (1 + 20 + 4). -pub const PATH_MTU_V6: usize = 1337; -/// The path MTU for IPv4 can be 20 bytes larger than for v6. -pub const PATH_MTU_V4: usize = PATH_MTU_V6 + 20; /// The number of times that a path will be probed before it is considered failed. const MAX_PATH_PROBES: usize = 3; /// The maximum number of paths that `Paths` will track. @@ -291,6 +284,10 @@ impl Paths { false } } else { + // See if the PMTUD raise timer wants to fire. 
+ if let Some(path) = self.primary() { + path.borrow_mut().pmtud_mut().maybe_fire_raise_timer(now); + } true } } @@ -558,7 +555,7 @@ impl Path { qlog: NeqoQlog, now: Instant, ) -> Self { - let mut sender = PacketSender::new(cc, pacing, Self::mtu_by_addr(remote.ip()), now); + let mut sender = PacketSender::new(cc, pacing, Pmtud::new(remote.ip()), now); sender.set_qlog(qlog.clone()); Self { local, @@ -652,16 +649,14 @@ impl Path { } } - const fn mtu_by_addr(addr: IpAddr) -> usize { - match addr { - IpAddr::V4(_) => PATH_MTU_V4, - IpAddr::V6(_) => PATH_MTU_V6, - } + /// Get the PL MTU. + pub fn plpmtu(&self) -> usize { + self.pmtud().plpmtu() } - /// Get the path MTU. This is currently fixed based on IP version. - pub const fn mtu(&self) -> usize { - Self::mtu_by_addr(self.remote.ip()) + /// Get a reference to the PMTUD state. + pub fn pmtud(&self) -> &Pmtud { + self.sender.pmtud() } /// Get the first local connection ID. @@ -783,7 +778,6 @@ impl Path { if builder.remaining() < 9 { return false; } - // Send PATH_RESPONSE. let resp_sent = if let Some(challenge) = self.challenge.take() { qtrace!([self], "Responding to path challenge {}", hex(challenge)); @@ -892,6 +886,11 @@ impl Path { &mut self.rtt } + /// Mutably borrow the PMTUD discoverer for this path. + pub fn pmtud_mut(&mut self) -> &mut Pmtud { + self.sender.pmtud_mut() + } + /// Read-only access to the owned sender. 
pub const fn sender(&self) -> &PacketSender { &self.sender @@ -913,7 +912,7 @@ impl Path { m, ack_ratio, self.sender.cwnd(), - self.mtu(), + self.plpmtu(), self.rtt.estimate(), ) }, @@ -976,6 +975,7 @@ impl Path { acked_pkts: &[SentPacket], ack_ecn: Option, now: Instant, + stats: &mut Stats, ) { debug_assert!(self.is_primary()); @@ -985,11 +985,12 @@ impl Path { .sender .on_ecn_ce_received(acked_pkts.first().expect("must be there")); if cwnd_reduced { - self.rtt.update_ack_delay(self.sender.cwnd(), self.mtu()); + self.rtt.update_ack_delay(self.sender.cwnd(), self.plpmtu()); } } - self.sender.on_packets_acked(acked_pkts, &self.rtt, now); + self.sender + .on_packets_acked(acked_pkts, &self.rtt, now, stats); } /// Record packets as lost with the sender. @@ -998,6 +999,8 @@ impl Path { prev_largest_acked_sent: Option, space: PacketNumberSpace, lost_packets: &[SentPacket], + stats: &mut Stats, + now: Instant, ) { debug_assert!(self.is_primary()); let cwnd_reduced = self.sender.on_packets_lost( @@ -1005,9 +1008,11 @@ impl Path { prev_largest_acked_sent, self.rtt.pto(space), // Important: the base PTO, not adjusted. lost_packets, + stats, + now, ); if cwnd_reduced { - self.rtt.update_ack_delay(self.sender.cwnd(), self.mtu()); + self.rtt.update_ack_delay(self.sender.cwnd(), self.plpmtu()); } } @@ -1025,7 +1030,7 @@ impl Path { // If we have received absolutely nothing thus far, then this endpoint // is the one initiating communication on this path. Allow enough space for // probing. - self.mtu() * 5 + self.plpmtu() * 5 } else { limit }; diff --git a/neqo-transport/src/pmtud.rs b/neqo-transport/src/pmtud.rs new file mode 100644 index 0000000000..5ee59e3dbf --- /dev/null +++ b/neqo-transport/src/pmtud.rs @@ -0,0 +1,688 @@ +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +use std::{ + iter::zip, + net::IpAddr, + time::{Duration, Instant}, +}; + +use neqo_common::{qdebug, qinfo}; +use static_assertions::const_assert; + +use crate::{frame::FRAME_TYPE_PING, packet::PacketBuilder, recovery::SentPacket, Stats}; + +// Values <= 1500 based on: A. Custura, G. Fairhurst and I. Learmonth, "Exploring Usable Path MTU in +// the Internet," 2018 Network Traffic Measurement and Analysis Conference (TMA), Vienna, Austria, +// 2018, pp. 1-8, doi: 10.23919/TMA.2018.8506538. keywords: +// {Servers;Probes;Tools;Clamps;Middleboxes;Standards}, +const MTU_SIZES_V4: &[usize] = &[ + 1280, 1380, 1420, 1472, 1500, 2047, 4095, 8191, 16383, 32767, 65535, +]; +const MTU_SIZES_V6: &[usize] = &[ + 1280, 1380, + 1420, // 1420 is not in the paper for v6, but adding it makes the arrays the same length + 1470, 1500, 2047, 4095, 8191, 16383, 32767, 65535, +]; +const_assert!(MTU_SIZES_V4.len() == MTU_SIZES_V6.len()); +const SEARCH_TABLE_LEN: usize = MTU_SIZES_V4.len(); + +// From https://datatracker.ietf.org/doc/html/rfc8899#section-5.1 +const MAX_PROBES: usize = 3; +const PMTU_RAISE_TIMER: Duration = Duration::from_secs(600); + +#[derive(Debug, PartialEq, Clone, Copy)] +enum Probe { + NotNeeded, + Needed, + Sent, +} + +#[derive(Debug)] +pub struct Pmtud { + search_table: &'static [usize], + header_size: usize, + mtu: usize, + probe_index: usize, + probe_count: usize, + probe_state: Probe, + loss_counts: [usize; SEARCH_TABLE_LEN], + raise_timer: Option, +} + +impl Pmtud { + /// Returns the MTU search table for the given remote IP address family. + const fn search_table(remote_ip: IpAddr) -> &'static [usize] { + match remote_ip { + IpAddr::V4(_) => MTU_SIZES_V4, + IpAddr::V6(_) => MTU_SIZES_V6, + } + } + + /// Size of the IPv4/IPv6 and UDP headers, in bytes. 
+ const fn header_size(remote_ip: IpAddr) -> usize { + match remote_ip { + IpAddr::V4(_) => 20 + 8, + IpAddr::V6(_) => 40 + 8, + } + } + + #[must_use] + pub const fn new(remote_ip: IpAddr) -> Self { + let search_table = Self::search_table(remote_ip); + let probe_index = 0; + Self { + search_table, + header_size: Self::header_size(remote_ip), + mtu: search_table[probe_index], + probe_index, + probe_count: 0, + probe_state: Probe::NotNeeded, + loss_counts: [0; SEARCH_TABLE_LEN], + raise_timer: None, + } + } + + /// Checks whether the PMTUD raise timer should be fired, and does so if needed. + pub fn maybe_fire_raise_timer(&mut self, now: Instant) { + if self.probe_state == Probe::NotNeeded && self.raise_timer.map_or(false, |t| now >= t) { + qdebug!("PMTUD raise timer fired"); + self.raise_timer = None; + self.start(); + } + } + + /// Returns the current Packetization Layer Path MTU, i.e., the maximum UDP payload that can be + /// sent. During probing, this may be smaller than the actual path MTU. + #[must_use] + pub const fn plpmtu(&self) -> usize { + self.mtu - self.header_size + } + + /// Returns true if a PMTUD probe should be sent. + #[must_use] + pub fn needs_probe(&self) -> bool { + self.probe_state == Probe::Needed + } + + /// Returns true if a PMTUD probe was sent. + #[must_use] + pub fn probe_sent(&self) -> bool { + self.probe_state == Probe::Sent + } + + /// Returns the size of the current PMTUD probe. + #[must_use] + pub const fn probe_size(&self) -> usize { + self.search_table[self.probe_index] - self.header_size + } + + /// Sends a PMTUD probe. + pub fn send_probe(&mut self, builder: &mut PacketBuilder, stats: &mut Stats) { + // The packet may include ACK-eliciting data already, but rather than check for that, it + // seems OK to burn one byte here to simply include a PING. 
+ builder.encode_varint(FRAME_TYPE_PING); + stats.frame_tx.ping += 1; + stats.frame_tx.all += 1; + stats.pmtud_tx += 1; + self.probe_count += 1; + self.probe_state = Probe::Sent; + qdebug!( + "Sending PMTUD probe of size {}, count {}", + self.search_table[self.probe_index], + self.probe_count + ); + } + + #[allow(rustdoc::private_intra_doc_links)] + /// Provides a [`Fn`] that returns true if the packet is a PMTUD probe. + /// + /// Allows filtering packets without holding a reference to [`Pmtud`]. When + /// in doubt, use [`Pmtud::is_probe`]. + pub fn is_probe_filter(&self) -> impl Fn(&SentPacket) -> bool { + let probe_state = self.probe_state; + let probe_size = self.probe_size(); + + move |p: &SentPacket| -> bool { probe_state == Probe::Sent && p.len() == probe_size } + } + + /// Returns true if the packet is a PMTUD probe. + fn is_probe(&self, p: &SentPacket) -> bool { + self.is_probe_filter()(p) + } + + /// Count the PMTUD probes included in `pkts`. + fn count_probes(&self, pkts: &[SentPacket]) -> usize { + pkts.iter().filter(|p| self.is_probe(p)).count() + } + + /// Checks whether a PMTUD probe has been acknowledged, and if so, updates the PMTUD state. + /// May also initiate a new probe process for a larger MTU. + pub fn on_packets_acked(&mut self, acked_pkts: &[SentPacket], stats: &mut Stats) { + // Reset the loss counts for all packets sizes <= the size of the largest ACKed packet. + let max_len = acked_pkts.iter().map(SentPacket::len).max().unwrap_or(0); + if max_len == 0 { + // No packets were ACKed, nothing to do. + return; + } + + let idx = self + .search_table + .iter() + .position(|&sz| sz > max_len + self.header_size) + .unwrap_or(SEARCH_TABLE_LEN); + self.loss_counts.iter_mut().take(idx).for_each(|c| *c = 0); + + let acked = self.count_probes(acked_pkts); + if acked == 0 { + return; + } + + // A probe was ACKed, confirm the new MTU and try to probe upwards further. 
+ // + // TODO: Maybe we should be tracking stats on a per-probe-size basis rather than just the + // total number of successful probes. + stats.pmtud_ack += acked; + self.mtu = self.search_table[self.probe_index]; + qdebug!("PMTUD probe of size {} succeeded", self.mtu); + self.start(); + } + + /// Stops the PMTUD process, setting the MTU to the largest successful probe size. + fn stop(&mut self, idx: usize, now: Instant) { + self.probe_state = Probe::NotNeeded; // We don't need to send any more probes + self.probe_index = idx; // Index of the last successful probe + self.mtu = self.search_table[idx]; // Leading to this MTU + self.probe_count = 0; // Reset the count + self.loss_counts.fill(0); // Reset the loss counts + self.raise_timer = Some(now + PMTU_RAISE_TIMER); + qinfo!( + "PMTUD stopped, PLPMTU is now {}, raise timer {:?}", + self.mtu, + self.raise_timer.unwrap() + ); + } + + /// Checks whether a PMTUD probe has been lost. If it has been lost more than `MAX_PROBES` + /// times, the PMTUD process is stopped. + pub fn on_packets_lost( + &mut self, + lost_packets: &[SentPacket], + stats: &mut Stats, + now: Instant, + ) { + if lost_packets.is_empty() { + return; + } + + let mut increase = [0; SEARCH_TABLE_LEN]; + let mut loss_counts_updated = false; + for p in lost_packets { + let Some(idx) = self + .search_table + .iter() + .position(|&sz| p.len() <= sz - self.header_size) + else { + continue; + }; + // Count each lost packet size <= the current MTU only once. Otherwise a burst loss of + // >= MAX_PROBES MTU-sized packets triggers a PMTUD restart. Counting only one of them + // here requires three consecutive loss instances of such sizes to trigger a PMTUD + // restart. + // + // Also, ignore losses of packets <= the minimum QUIC packet size, (`searchtable[0]`), + // since they just increase loss counts across the board, adding to spurious + // PMTUD restarts. 
+ if idx > 0 && (increase[idx] == 0 || p.len() > self.plpmtu()) { + loss_counts_updated = true; + increase[idx] += 1; + } + } + + if !loss_counts_updated { + return; + } + + let mut accum = 0; + for (c, incr) in zip(&mut self.loss_counts, increase) { + accum += incr; + *c += accum; + } + + // Track lost probes + let lost = self.count_probes(lost_packets); + stats.pmtud_lost += lost; + + // Check if any packet sizes have been lost MAX_PROBES times or more. + // + // TODO: It's not clear that MAX_PROBES is the right number for losses of packets that + // aren't PMTU probes. We might want to be more conservative, to avoid spurious PMTUD + // restarts. + let Some(first_failed) = self.loss_counts.iter().position(|&c| c >= MAX_PROBES) else { + // If not, keep going. + if lost > 0 { + // Don't stop the PMTUD process. + self.probe_state = Probe::Needed; + } + return; + }; + + let last_ok = first_failed - 1; + qdebug!( + "Packet of size > {} lost >= {} times", + self.search_table[last_ok], + MAX_PROBES + ); + if self.probe_state == Probe::NotNeeded { + // We saw multiple losses of packets <= the current MTU outside of PMTU discovery, + // so we need to probe again. To limit connectivity disruptions, we start the PMTU + // discovery from the smallest packet up, rather than the failed packet size down. + // + // TODO: If we are declaring losses, that means that we're getting packets through. + // The size of those will put a floor on the MTU. We're currently conservative and + // start from scratch, but we don't strictly need to do that. + self.restart(stats); + } else { + // We saw multiple losses of packets > the current MTU during PMTU discovery, so + // we're done. 
+ self.stop(last_ok, now); + } + } + + fn restart(&mut self, stats: &mut Stats) { + self.probe_index = 0; + self.mtu = self.search_table[self.probe_index]; + self.loss_counts.fill(0); + self.raise_timer = None; + stats.pmtud_change += 1; + qdebug!("PMTUD restarted, PLPMTU is now {}", self.mtu); + self.start(); + } + + /// Starts the next upward PMTUD probe. + pub fn start(&mut self) { + if self.probe_index < SEARCH_TABLE_LEN - 1 { + self.probe_state = Probe::Needed; // We need to send a probe + self.probe_count = 0; // For the first time + self.probe_index += 1; // At this size + qdebug!( + "PMTUD started with probe size {}", + self.search_table[self.probe_index], + ); + } else { + // If we're at the end of the search table, we're done. + self.probe_state = Probe::NotNeeded; + } + } + + /// Returns the default PLPMTU for the given remote IP address. + #[must_use] + pub const fn default_plpmtu(remote_ip: IpAddr) -> usize { + let search_table = Self::search_table(remote_ip); + search_table[0] - Self::header_size(remote_ip) + } +} + +#[cfg(all(not(feature = "disable-encryption"), test))] +mod tests { + use std::{ + iter::zip, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + time::Instant, + }; + + use neqo_common::{qdebug, Encoder, IpTosEcn}; + use test_fixture::{fixture_init, now}; + + use crate::{ + crypto::CryptoDxState, + packet::{PacketBuilder, PacketType}, + pmtud::{Probe, PMTU_RAISE_TIMER, SEARCH_TABLE_LEN}, + recovery::{SendProfile, SentPacket}, + Pmtud, Stats, + }; + + const V4: IpAddr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); + const V6: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)); + + fn make_sentpacket(pn: u64, now: Instant, len: usize) -> SentPacket { + SentPacket::new( + PacketType::Short, + pn, + IpTosEcn::default(), + now, + true, + Vec::new(), + len, + ) + } + + fn assert_mtu(pmtud: &Pmtud, mtu: usize) { + let idx = pmtud + .search_table + .iter() + .position(|x| *x == pmtud.mtu) + .unwrap(); + assert!(mtu >= pmtud.search_table[idx]); + if idx < 
SEARCH_TABLE_LEN - 1 { + assert!(mtu < pmtud.search_table[idx + 1]); + } + } + + fn pmtud_step( + pmtud: &mut Pmtud, + stats: &mut Stats, + prot: &mut CryptoDxState, + addr: IpAddr, + mtu: usize, + now: Instant, + ) { + let stats_before = stats.clone(); + + // Fake a packet number, so the builder logic works. + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let pn = prot.next_pn(); + builder.pn(pn, 4); + builder.set_initial_limit(&SendProfile::new_limited(pmtud.plpmtu()), 16, pmtud); + builder.enable_padding(true); + pmtud.send_probe(&mut builder, stats); + builder.pad(); + let encoder = builder.build(prot).unwrap(); + assert_eq!(encoder.len(), pmtud.probe_size()); + assert!(!pmtud.needs_probe()); + assert_eq!(stats_before.pmtud_tx + 1, stats.pmtud_tx); + + let packet = make_sentpacket(pn, now, encoder.len()); + if encoder.len() + Pmtud::header_size(addr) <= mtu { + pmtud.on_packets_acked(&[packet], stats); + assert_eq!(stats_before.pmtud_ack + 1, stats.pmtud_ack); + } else { + pmtud.on_packets_lost(&[packet], stats, now); + assert_eq!(stats_before.pmtud_lost + 1, stats.pmtud_lost); + } + } + + fn find_pmtu(addr: IpAddr, mtu: usize) { + fixture_init(); + let now = now(); + let mut pmtud = Pmtud::new(addr); + let mut stats = Stats::default(); + let mut prot = CryptoDxState::test_default(); + + pmtud.start(); + assert!(pmtud.needs_probe()); + + while pmtud.needs_probe() { + pmtud_step(&mut pmtud, &mut stats, &mut prot, addr, mtu, now); + } + assert_mtu(&pmtud, mtu); + } + + #[test] + fn pmtud_v4_max() { + find_pmtu(V4, u16::MAX.into()); + } + + #[test] + fn pmtud_v6_max() { + find_pmtu(V6, u16::MAX.into()); + } + + #[test] + fn pmtud_v4_1500() { + find_pmtu(V4, 1500); + } + + #[test] + fn pmtud_v6_1500() { + find_pmtu(V6, 1500); + } + + fn find_pmtu_with_reduction(addr: IpAddr, mtu: usize, smaller_mtu: usize) { + assert!(mtu > smaller_mtu); + + fixture_init(); + let now = now(); + let mut pmtud = Pmtud::new(addr); + let mut stats = 
Stats::default();
+        let mut prot = CryptoDxState::test_default();
+
+        assert!(smaller_mtu >= pmtud.search_table[0]);
+        pmtud.start();
+        assert!(pmtud.needs_probe());
+
+        while pmtud.needs_probe() {
+            pmtud_step(&mut pmtud, &mut stats, &mut prot, addr, mtu, now);
+        }
+        assert_mtu(&pmtud, mtu);
+
+        qdebug!("Reducing MTU to {}", smaller_mtu);
+        // Drop packets > smaller_mtu until we need a probe again.
+        while !pmtud.needs_probe() {
+            let pn = prot.next_pn();
+            let packet = make_sentpacket(pn, now, pmtud.mtu - pmtud.header_size);
+            pmtud.on_packets_lost(&[packet], &mut stats, now);
+        }
+
+        // Drive second PMTUD process to completion.
+        while pmtud.needs_probe() {
+            pmtud_step(&mut pmtud, &mut stats, &mut prot, addr, mtu, now);
+        }
+        assert_mtu(&pmtud, mtu);
+    }
+
+    #[test]
+    fn pmtud_v4_max_1300() {
+        find_pmtu_with_reduction(V4, u16::MAX.into(), 1300);
+    }
+
+    #[test]
+    fn pmtud_v6_max_1280() {
+        find_pmtu_with_reduction(V6, u16::MAX.into(), 1280);
+    }
+
+    #[test]
+    fn pmtud_v4_1500_1300() {
+        find_pmtu_with_reduction(V4, 1500, 1300);
+    }
+
+    #[test]
+    fn pmtud_v6_1500_1280() {
+        find_pmtu_with_reduction(V6, 1500, 1280);
+    }
+
+    fn find_pmtu_with_increase(addr: IpAddr, mtu: usize, larger_mtu: usize) {
+        assert!(mtu < larger_mtu);
+
+        fixture_init();
+        let now = now();
+        let mut pmtud = Pmtud::new(addr);
+        let mut stats = Stats::default();
+        let mut prot = CryptoDxState::test_default();
+
+        assert!(larger_mtu >= pmtud.search_table[0]);
+        pmtud.start();
+        assert!(pmtud.needs_probe());
+
+        while pmtud.needs_probe() {
+            pmtud_step(&mut pmtud, &mut stats, &mut prot, addr, mtu, now);
+        }
+        assert_mtu(&pmtud, mtu);
+
+        qdebug!("Increasing MTU to {}", larger_mtu);
+        let now = now + PMTU_RAISE_TIMER;
+        pmtud.maybe_fire_raise_timer(now);
+        while pmtud.needs_probe() {
+            pmtud_step(&mut pmtud, &mut stats, &mut prot, addr, larger_mtu, now);
+        }
+        assert_mtu(&pmtud, larger_mtu);
+    }
+
+    #[test]
+    fn pmtud_v4_1300_max() {
+        find_pmtu_with_increase(V4, 1300, u16::MAX.into());
+    }
+
+
+    #[test]
+    fn pmtud_v6_1280_max() {
+        find_pmtu_with_increase(V6, 1280, u16::MAX.into());
+    }
+
+    #[test]
+    fn pmtud_v4_1300_1500() {
+        find_pmtu_with_increase(V4, 1300, 1500);
+    }
+
+    #[test]
+    fn pmtud_v6_1280_1500() {
+        find_pmtu_with_increase(V6, 1280, 1500);
+    }
+
+    /// Increments the loss counts for the given search table, based on the given packet size.
+    fn search_table_inc(pmtud: &Pmtud, loss_counts: &[usize], sz: usize) -> Vec<usize> {
+        zip(pmtud.search_table, loss_counts.iter())
+            .map(|(&s, &c)| {
+                if s >= sz + pmtud.header_size {
+                    c + 1
+                } else {
+                    c
+                }
+            })
+            .collect()
+    }
+
+    /// Asserts that the PMTUD process has restarted.
+    fn assert_pmtud_restarted(pmtud: &Pmtud) {
+        assert_eq!(Probe::Needed, pmtud.probe_state);
+        assert_eq!(pmtud.mtu, pmtud.search_table[0]);
+        assert_eq!([0; SEARCH_TABLE_LEN], pmtud.loss_counts);
+    }
+
+    /// Asserts that the PMTUD process has stopped at the given MTU.
+    fn assert_pmtud_stopped(pmtud: &Pmtud, mtu: usize) {
+        // assert_eq!(Probe::NotNeeded, pmtud.probe_state);
+        assert_eq!(pmtud.mtu, mtu);
+        assert_eq!([0; SEARCH_TABLE_LEN], pmtud.loss_counts);
+    }
+
+    #[test]
+    fn pmtud_on_packets_lost() {
+        let now = now();
+        let mut pmtud = Pmtud::new(V4);
+        let mut stats = Stats::default();
+
+        // No packets lost, nothing should change.
+        pmtud.on_packets_lost(&[], &mut stats, now);
+        assert_eq!([0; SEARCH_TABLE_LEN], pmtud.loss_counts);
+
+        // A packet of size 100 was lost, which is smaller than all probe sizes.
+        // Loss counts should be unchanged.
+        pmtud.on_packets_lost(&[make_sentpacket(0, now, 100)], &mut stats, now);
+        assert_eq!([0; SEARCH_TABLE_LEN], pmtud.loss_counts);
+
+        // A packet of size 100_000 was lost, which is larger than all probe sizes.
+        // Loss counts should be unchanged.
+        pmtud.on_packets_lost(&[make_sentpacket(0, now, 100_000)], &mut stats, now);
+        assert_eq!([0; SEARCH_TABLE_LEN], pmtud.loss_counts);
+
+        pmtud.loss_counts.fill(0); // Reset the loss counts.
+ + // A packet of size 1500 was lost, which should increase loss counts >= 1500 by one. + let plen = 1500 - pmtud.header_size; + let mut expected_lc = search_table_inc(&pmtud, &pmtud.loss_counts, plen); + pmtud.on_packets_lost(&[make_sentpacket(0, now, plen)], &mut stats, now); + assert_eq!(expected_lc, pmtud.loss_counts); + + // A packet of size 2000 was lost, which should increase loss counts >= 2000 by one. + expected_lc = search_table_inc(&pmtud, &expected_lc, 2000); + pmtud.on_packets_lost(&[make_sentpacket(0, now, 2000)], &mut stats, now); + assert_eq!(expected_lc, pmtud.loss_counts); + + // A packet of size 5000 was lost, which should increase loss counts >= 5000 by one. There + // have now been MAX_PROBES losses of packets >= 5000, so the PMTUD process should have + // restarted. + pmtud.on_packets_lost(&[make_sentpacket(0, now, 5000)], &mut stats, now); + assert_pmtud_restarted(&pmtud); + expected_lc.fill(0); // Reset the expected loss counts. + + // Two packets of size 4000 were lost, which should increase loss counts >= 4000 by two. + let expected_lc = search_table_inc(&pmtud, &expected_lc, 4000); + let expected_lc = search_table_inc(&pmtud, &expected_lc, 4000); + pmtud.on_packets_lost( + &[make_sentpacket(0, now, 4000), make_sentpacket(0, now, 4000)], + &mut stats, + now, + ); + assert_eq!(expected_lc, pmtud.loss_counts); + + // A packet of size 2000 was lost, which should increase loss counts >= 2000 by one. There + // have now been MAX_PROBES losses of packets >= 4000, so the PMTUD process should have + // stopped. + pmtud.on_packets_lost( + &[make_sentpacket(0, now, 2000), make_sentpacket(0, now, 2000)], + &mut stats, + now, + ); + assert_pmtud_stopped(&pmtud, 2047); + } + + /// Zeros the loss counts for the given search table, below the given packet size. 
+    fn search_table_zero(pmtud: &Pmtud, loss_counts: &[usize], sz: usize) -> Vec<usize> {
+        zip(pmtud.search_table, loss_counts.iter())
+            .map(|(&s, &c)| if s <= sz + pmtud.header_size { 0 } else { c })
+            .collect()
+    }
+
+    #[test]
+    fn pmtud_on_packets_lost_and_acked() {
+        let now = now();
+        let mut pmtud = Pmtud::new(V4);
+        let mut stats = Stats::default();
+
+        // A packet of size 100 was ACKed, which is smaller than all probe sizes.
+        // Loss counts should be unchanged.
+        pmtud.on_packets_acked(&[make_sentpacket(0, now, 100)], &mut stats);
+        assert_eq!([0; SEARCH_TABLE_LEN], pmtud.loss_counts);
+
+        // A packet of size 100_000 was ACKed, which is larger than all probe sizes.
+        // Loss counts should be unchanged.
+        pmtud.on_packets_acked(&[make_sentpacket(0, now, 100_000)], &mut stats);
+        assert_eq!([0; SEARCH_TABLE_LEN], pmtud.loss_counts);
+
+        pmtud.loss_counts.fill(0); // Reset the loss counts.
+
+        // No packets ACKed, nothing should change.
+        pmtud.on_packets_acked(&[], &mut stats);
+        assert_eq!([0; SEARCH_TABLE_LEN], pmtud.loss_counts);
+
+        // One packet of size 4000 was lost, which should increase loss counts >= 4000 by one.
+        let expected_lc = search_table_inc(&pmtud, &pmtud.loss_counts, 4000);
+        pmtud.on_packets_lost(&[make_sentpacket(0, now, 4000)], &mut stats, now);
+        assert_eq!(expected_lc, pmtud.loss_counts);
+
+        // Now a packet of size 5000 is ACKed, which should reset all loss counts <= 5000.
+        pmtud.on_packets_acked(&[make_sentpacket(0, now, 5000)], &mut stats);
+        let expected_lc = search_table_zero(&pmtud, &pmtud.loss_counts, 5000);
+        assert_eq!(expected_lc, pmtud.loss_counts);
+
+        // Now, one more packet of size 4000 was lost, which should increase loss counts >= 4000
+        // by one.
+        let expected_lc = search_table_inc(&pmtud, &expected_lc, 4000);
+        pmtud.on_packets_lost(&[make_sentpacket(0, now, 4000)], &mut stats, now);
+        assert_eq!(expected_lc, pmtud.loss_counts);
+
+        // Now a packet of size 8000 is ACKed, which should reset all loss counts <= 8000.
+ pmtud.on_packets_acked(&[make_sentpacket(0, now, 8000)], &mut stats); + let expected_lc = search_table_zero(&pmtud, &pmtud.loss_counts, 8000); + assert_eq!(expected_lc, pmtud.loss_counts); + + // Now, one more packets of size 9000 was lost, which should increase loss counts >= 9000 + // by one. There have now been MAX_PROBES losses of packets >= 8191, so the PMTUD process + // should have restarted. + pmtud.on_packets_lost(&[make_sentpacket(0, now, 9000)], &mut stats, now); + assert_pmtud_restarted(&pmtud); + } +} diff --git a/neqo-transport/src/recovery/mod.rs b/neqo-transport/src/recovery/mod.rs index bec3664118..e697e78695 100644 --- a/neqo-transport/src/recovery/mod.rs +++ b/neqo-transport/src/recovery/mod.rs @@ -653,16 +653,23 @@ impl LossRecovery { // Tell the congestion controller about any lost packets. // The PTO for congestion control is the raw number, without exponential // backoff, so that we can determine persistent congestion. - primary_path - .borrow_mut() - .on_packets_lost(prev_largest_acked, pn_space, &lost); + primary_path.borrow_mut().on_packets_lost( + prev_largest_acked, + pn_space, + &lost, + &mut self.stats.borrow_mut(), + now, + ); // This must happen after on_packets_lost. If in recovery, this could // take us out, and then lost packets will start a new recovery period // when it shouldn't. 
- primary_path - .borrow_mut() - .on_packets_acked(&acked_packets, ack_ecn, now); + primary_path.borrow_mut().on_packets_acked( + &acked_packets, + ack_ecn, + now, + &mut self.stats.borrow_mut(), + ); self.pto_state = None; @@ -880,6 +887,8 @@ impl LossRecovery { space.largest_acked_sent_time, space.space(), &lost_packets[first..], + &mut self.stats.borrow_mut(), + now, ); } self.stats.borrow_mut().lost += lost_packets.len(); @@ -894,7 +903,7 @@ impl LossRecovery { pub fn send_profile(&mut self, path: &Path, now: Instant) -> SendProfile { qdebug!([self], "get send profile {:?}", now); let sender = path.sender(); - let mtu = path.mtu(); + let mtu = path.plpmtu(); if let Some(profile) = self .pto_state .as_mut() diff --git a/neqo-transport/src/sender.rs b/neqo-transport/src/sender.rs index 22abef4dc7..a9ead627aa 100644 --- a/neqo-transport/src/sender.rs +++ b/neqo-transport/src/sender.rs @@ -9,17 +9,19 @@ #![allow(clippy::module_name_repetitions)] use std::{ - fmt::{self, Debug, Display}, + fmt::{self, Display}, time::{Duration, Instant}, }; -use neqo_common::qlog::NeqoQlog; +use neqo_common::{qdebug, qlog::NeqoQlog}; use crate::{ cc::{ClassicCongestionControl, CongestionControl, CongestionControlAlgorithm, Cubic, NewReno}, pace::Pacer, + pmtud::Pmtud, recovery::SentPacket, rtt::RttEstimate, + Stats, }; /// The number of packets we allow to burst from the pacer. 
@@ -42,16 +44,17 @@ impl PacketSender { pub fn new( alg: CongestionControlAlgorithm, pacing_enabled: bool, - mtu: usize, + pmtud: Pmtud, now: Instant, ) -> Self { + let mtu = pmtud.plpmtu(); Self { cc: match alg { CongestionControlAlgorithm::NewReno => { - Box::new(ClassicCongestionControl::new(NewReno::default())) + Box::new(ClassicCongestionControl::new(NewReno::default(), pmtud)) } CongestionControlAlgorithm::Cubic => { - Box::new(ClassicCongestionControl::new(Cubic::default())) + Box::new(ClassicCongestionControl::new(Cubic::default(), pmtud)) } }, pacer: Pacer::new(pacing_enabled, now, mtu * PACING_BURST_SIZE, mtu), @@ -62,6 +65,14 @@ impl PacketSender { self.cc.set_qlog(qlog); } + pub fn pmtud(&self) -> &Pmtud { + self.cc.pmtud() + } + + pub fn pmtud_mut(&mut self) -> &mut Pmtud { + self.cc.pmtud_mut() + } + #[must_use] pub fn cwnd(&self) -> usize { self.cc.cwnd() @@ -72,13 +83,34 @@ impl PacketSender { self.cc.cwnd_avail() } + #[cfg(test)] + #[must_use] + pub fn cwnd_min(&self) -> usize { + self.cc.cwnd_min() + } + + fn maybe_update_pacer_mtu(&mut self) { + let current_mtu = self.pmtud().plpmtu(); + if current_mtu != self.pacer.mtu() { + qdebug!( + "PLPMTU changed from {} to {}, updating pacer", + self.pacer.mtu(), + current_mtu + ); + self.pacer.set_mtu(current_mtu); + } + } + pub fn on_packets_acked( &mut self, acked_pkts: &[SentPacket], rtt_est: &RttEstimate, now: Instant, + stats: &mut Stats, ) { self.cc.on_packets_acked(acked_pkts, rtt_est, now); + self.pmtud_mut().on_packets_acked(acked_pkts, stats); + self.maybe_update_pacer_mtu(); } /// Called when packets are lost. Returns true if the congestion window was reduced. 
@@ -88,13 +120,20 @@ impl PacketSender { prev_largest_acked_sent: Option, pto: Duration, lost_packets: &[SentPacket], + stats: &mut Stats, + now: Instant, ) -> bool { - self.cc.on_packets_lost( + let ret = self.cc.on_packets_lost( first_rtt_sample_time, prev_largest_acked_sent, pto, lost_packets, - ) + ); + // Call below may change the size of MTU probes, so it needs to happen after the CC + // reaction above, which needs to ignore probes based on their size. + self.pmtud_mut().on_packets_lost(lost_packets, stats, now); + self.maybe_update_pacer_mtu(); + ret } /// Called when ECN CE mark received. Returns true if the congestion window was reduced. diff --git a/neqo-transport/src/stats.rs b/neqo-transport/src/stats.rs index 0c4b604671..b7342227a4 100644 --- a/neqo-transport/src/stats.rs +++ b/neqo-transport/src/stats.rs @@ -134,6 +134,14 @@ pub struct Stats { /// Acknowledgments for packets that contained data that was marked /// for retransmission when the PTO timer popped. pub pto_ack: usize, + /// Number of PMTUD probes sent. + pub pmtud_tx: usize, + /// Number of PMTUD probes ACK'ed. + pub pmtud_ack: usize, + /// Number of PMTUD probes lost. + pub pmtud_lost: usize, + /// Number of times a path MTU changed unexpectedly. + pub pmtud_change: usize, /// Whether the connection was resumed successfully. 
pub resumed: bool, @@ -206,6 +214,11 @@ impl Debug for Stats { " tx: {} lost {} lateack {} ptoack {}", self.packets_tx, self.lost, self.late_ack, self.pto_ack )?; + writeln!( + f, + " pmtud: {} sent {} acked {} lost {} change", + self.pmtud_tx, self.pmtud_ack, self.pmtud_lost, self.pmtud_change + )?; writeln!(f, " resumed: {}", self.resumed)?; writeln!(f, " frames rx:")?; self.frame_rx.fmt(f)?; diff --git a/test-fixture/src/assertions.rs b/test-fixture/src/assertions.rs index 9e62e7167e..52d0194cbb 100644 --- a/test-fixture/src/assertions.rs +++ b/test-fixture/src/assertions.rs @@ -7,7 +7,7 @@ use std::net::SocketAddr; use neqo_common::{Datagram, Decoder}; -use neqo_transport::{version::WireVersion, Version, MIN_INITIAL_PACKET_SIZE}; +use neqo_transport::{version::WireVersion, Pmtud, Version, MIN_INITIAL_PACKET_SIZE}; use crate::{DEFAULT_ADDR, DEFAULT_ADDR_V4}; @@ -160,7 +160,7 @@ pub fn assert_path(dgram: &Datagram, path_addr: SocketAddr) { pub fn assert_v4_path(dgram: &Datagram, padded: bool) { assert_path(dgram, DEFAULT_ADDR_V4); if padded { - assert_eq!(dgram.len(), 1357 /* PATH_MTU_V4 */); + assert_eq!(dgram.len(), Pmtud::default_plpmtu(DEFAULT_ADDR_V4.ip())); } } @@ -170,6 +170,6 @@ pub fn assert_v4_path(dgram: &Datagram, padded: bool) { pub fn assert_v6_path(dgram: &Datagram, padded: bool) { assert_path(dgram, DEFAULT_ADDR); if padded { - assert_eq!(dgram.len(), 1337 /* PATH_MTU_V6 */); + assert_eq!(dgram.len(), Pmtud::default_plpmtu(DEFAULT_ADDR.ip())); } } diff --git a/test-fixture/src/sim/connection.rs b/test-fixture/src/sim/connection.rs index 58dd4bce23..53b89352f9 100644 --- a/test-fixture/src/sim/connection.rs +++ b/test-fixture/src/sim/connection.rs @@ -81,7 +81,7 @@ impl ConnectionNode { pub fn default_client(goals: impl IntoIterator>) -> Self { Self::new_client( - ConnectionParameters::default(), + ConnectionParameters::default().pmtud(true), boxed![ReachState::new(State::Confirmed)], goals, ) @@ -89,7 +89,7 @@ impl ConnectionNode { pub fn 
default_server(goals: impl IntoIterator>) -> Self { Self::new_server( - ConnectionParameters::default(), + ConnectionParameters::default().pmtud(true), boxed![ReachState::new(State::Confirmed)], goals, ) diff --git a/test/test.sh b/test/test.sh index 195f8d7297..99286a19d7 100755 --- a/test/test.sh +++ b/test/test.sh @@ -28,7 +28,7 @@ server="SSLKEYLOGFILE=$tmp/test.tlskey ./target/debug/neqo-server $flags $addr:$ tcpdump -U -i "$iface" -w "$tmp/test.pcap" host $addr and port $port >/dev/null 2>&1 & tcpdump_pid=$! -trap 'rm -rf "$tmp"; kill -USR2 $tcpdump_pid' EXIT +trap 'kill $tcpdump_pid; rm -rf "$tmp"' EXIT tmux -CC \ set-option -g default-shell "$(which bash)" \; \