Skip to content

Commit

Permalink
Progress
Browse files Browse the repository at this point in the history
  • Loading branch information
larseggert committed May 14, 2024
1 parent 76f3fd4 commit 3cc307b
Show file tree
Hide file tree
Showing 14 changed files with 246 additions and 183 deletions.
127 changes: 71 additions & 56 deletions neqo-transport/src/cc/classic_cc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@ use ::qlog::events::{quic::CongestionStateUpdated, EventData};
use neqo_common::{const_max, const_min, qdebug, qinfo, qlog::NeqoQlog, qtrace};

pub const CWND_INITIAL_PKTS: usize = 10;
pub const CWND_INITIAL: usize = const_min(
CWND_INITIAL_PKTS * crate::MIN_INITIAL_PACKET_SIZE,
const_max(2 * crate::MIN_INITIAL_PACKET_SIZE, 14720),
);
pub const CWND_MIN: usize = crate::MIN_INITIAL_PACKET_SIZE * 2;
const PERSISTENT_CONG_THRESH: u32 = 3;

pub const fn cwnd_initial(mtu: usize) -> usize {
const_min(CWND_INITIAL_PKTS * mtu, const_max(2 * mtu, 14_720))
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
/// In either slow start or congestion avoidance, not recovery.
Expand Down Expand Up @@ -131,7 +130,7 @@ pub struct ClassicCongestionControl<T> {

impl<T> ClassicCongestionControl<T> {
pub fn max_datagram_size(&self) -> usize {
self.pmtud.max_datagram_size()
self.pmtud.mtu()
}
}

Expand All @@ -156,6 +155,11 @@ impl<T: WindowAdjustment> CongestionControl for ClassicCongestionControl<T> {
self.congestion_window
}

#[must_use]
fn cwnd_initial(&self) -> usize {
cwnd_initial(self.pmtud.mtu())
}

#[must_use]
fn bytes_in_flight(&self) -> usize {
self.bytes_in_flight
Expand All @@ -168,6 +172,11 @@ impl<T: WindowAdjustment> CongestionControl for ClassicCongestionControl<T> {
self.congestion_window.saturating_sub(self.bytes_in_flight)
}

#[must_use]
fn cwnd_min(&self) -> usize {
self.max_datagram_size() * 2
}

// Multi-packet version of OnPacketAckedCC
fn on_packets_acked(&mut self, acked_pkts: &[SentPacket], rtt_est: &RttEstimate, now: Instant) {
let mut is_app_limited = true;
Expand Down Expand Up @@ -378,7 +387,7 @@ impl<T: WindowAdjustment> ClassicCongestionControl<T> {
Self {
cc_algorithm,
state: State::SlowStart,
congestion_window: CWND_INITIAL,
congestion_window: cwnd_initial(pmtud.mtu()),
bytes_in_flight: 0,
acked_bytes: 0,
ssthresh: usize::MAX,
Expand Down Expand Up @@ -477,7 +486,7 @@ impl<T: WindowAdjustment> ClassicCongestionControl<T> {
.expect("time is monotonic");
if elapsed > pc_period {
qinfo!([self], "persistent congestion");
self.congestion_window = CWND_MIN;
self.congestion_window = self.cwnd_min();
self.acked_bytes = 0;
self.set_state(State::PersistentCongestion);
qlog::metrics_updated(
Expand All @@ -504,6 +513,11 @@ impl<T: WindowAdjustment> ClassicCongestionControl<T> {
!self.state.transient() && self.recovery_start.map_or(true, |pn| packet.pn() >= pn)
}

#[must_use]
fn cwnd_min(&self) -> usize {
self.max_datagram_size() * 2
}

/// Handle a congestion event.
/// Returns true if this was a true congestion event.
fn on_congestion_event(&mut self, last_packet: &SentPacket) -> bool {
Expand All @@ -518,7 +532,7 @@ impl<T: WindowAdjustment> ClassicCongestionControl<T> {
self.acked_bytes,
self.max_datagram_size(),
);
self.congestion_window = max(cwnd, CWND_MIN);
self.congestion_window = max(cwnd, self.cwnd_min());
self.acked_bytes = acked_bytes;
self.ssthresh = self.congestion_window;
qdebug!(
Expand Down Expand Up @@ -558,14 +572,15 @@ impl<T: WindowAdjustment> ClassicCongestionControl<T> {

#[cfg(test)]
mod tests {
use std::time::{Duration, Instant};
use std::{
net::{IpAddr, Ipv4Addr},
time::{Duration, Instant},
};

use neqo_common::{qinfo, IpTosEcn};
use test_fixture::now;

use super::{
ClassicCongestionControl, WindowAdjustment, CWND_INITIAL, CWND_MIN, PERSISTENT_CONG_THRESH,
};
use super::{ClassicCongestionControl, WindowAdjustment, PERSISTENT_CONG_THRESH};
use crate::{
cc::{
classic_cc::State,
Expand All @@ -579,6 +594,7 @@ mod tests {
rtt::RttEstimate,
};

const IP_ADDR: IpAddr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0));
const PTO: Duration = Duration::from_millis(100);
const RTT: Duration = Duration::from_millis(98);
const RTT_ESTIMATE: RttEstimate = RttEstimate::from_duration(Duration::from_millis(98));
Expand All @@ -592,13 +608,13 @@ mod tests {
const PC: Duration = Duration::from_nanos(100_000_000 * (PERSISTENT_CONG_THRESH as u64) + 1);

fn cwnd_is_default(cc: &ClassicCongestionControl<NewReno>) {
assert_eq!(cc.cwnd(), CWND_INITIAL);
assert_eq!(cc.cwnd(), cc.cwnd_initial());
assert_eq!(cc.ssthresh(), usize::MAX);
}

fn cwnd_is_halved(cc: &ClassicCongestionControl<NewReno>) {
assert_eq!(cc.cwnd(), CWND_INITIAL / 2);
assert_eq!(cc.ssthresh(), CWND_INITIAL / 2);
assert_eq!(cc.cwnd(), cc.cwnd_initial() / 2);
assert_eq!(cc.ssthresh(), cc.cwnd_initial() / 2);
}

fn lost(pn: PacketNumber, ack_eliciting: bool, t: Duration) -> SentPacket {
Expand All @@ -617,22 +633,21 @@ mod tests {
match cc {
CongestionControlAlgorithm::NewReno => Box::new(ClassicCongestionControl::new(
NewReno::default(),
PmtudState::new(),
PmtudState::new(IP_ADDR),
)),
CongestionControlAlgorithm::Cubic => Box::new(ClassicCongestionControl::new(
Cubic::default(),
PmtudState::new(),
PmtudState::new(IP_ADDR),
)),
}
}

fn persistent_congestion_by_algorithm(
cc_alg: CongestionControlAlgorithm,
mut cc: Box<dyn CongestionControl>,
reduced_cwnd: usize,
lost_packets: &[SentPacket],
persistent_expected: bool,
) {
let mut cc = congestion_control(cc_alg);
for p in lost_packets {
cc.on_packet_sent(p);
}
Expand All @@ -641,7 +656,7 @@ mod tests {

let persistent = if cc.cwnd() == reduced_cwnd {
false
} else if cc.cwnd() == CWND_MIN {
} else if cc.cwnd() == cc.cwnd_min() {
true
} else {
panic!("unexpected cwnd");
Expand All @@ -650,15 +665,15 @@ mod tests {
}

fn persistent_congestion(lost_packets: &[SentPacket], persistent_expected: bool) {
let cc = congestion_control(CongestionControlAlgorithm::NewReno);
let cwnd_initial = cc.cwnd_initial();
persistent_congestion_by_algorithm(cc, cwnd_initial / 2, lost_packets, persistent_expected);

let cc = congestion_control(CongestionControlAlgorithm::Cubic);
let cwnd_initial = cc.cwnd_initial();
persistent_congestion_by_algorithm(
CongestionControlAlgorithm::NewReno,
CWND_INITIAL / 2,
lost_packets,
persistent_expected,
);
persistent_congestion_by_algorithm(
CongestionControlAlgorithm::Cubic,
CWND_INITIAL * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR,
cc,
cwnd_initial * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR,
lost_packets,
persistent_expected,
);
Expand Down Expand Up @@ -840,33 +855,33 @@ mod tests {
rtt_time: u32,
lost: &[SentPacket],
) -> bool {
assert_eq!(cc.cwnd(), CWND_INITIAL);
assert_eq!(cc.cwnd(), cc.cwnd_initial());

let last_ack = Some(by_pto(last_ack));
let rtt_time = Some(by_pto(rtt_time));

// Persistent congestion is never declared if the RTT time is `None`.
cc.detect_persistent_congestion(None, None, PTO, lost);
assert_eq!(cc.cwnd(), CWND_INITIAL);
assert_eq!(cc.cwnd(), cc.cwnd_initial());
cc.detect_persistent_congestion(None, last_ack, PTO, lost);
assert_eq!(cc.cwnd(), CWND_INITIAL);
assert_eq!(cc.cwnd(), cc.cwnd_initial());

cc.detect_persistent_congestion(rtt_time, last_ack, PTO, lost);
cc.cwnd() == CWND_MIN
cc.cwnd() == cc.cwnd_min()
}

/// No persistent congestion can be had if there are no lost packets.
#[test]
fn persistent_congestion_no_lost() {
let lost = make_lost(&[]);
assert!(!persistent_congestion_by_pto(
ClassicCongestionControl::new(NewReno::default(), PmtudState::new()),
ClassicCongestionControl::new(NewReno::default(), PmtudState::new(IP_ADDR)),
0,
0,
&lost
));
assert!(!persistent_congestion_by_pto(
ClassicCongestionControl::new(Cubic::default(), PmtudState::new()),
ClassicCongestionControl::new(Cubic::default(), PmtudState::new(IP_ADDR)),
0,
0,
&lost
Expand All @@ -878,13 +893,13 @@ mod tests {
fn persistent_congestion_one_lost() {
let lost = make_lost(&[1]);
assert!(!persistent_congestion_by_pto(
ClassicCongestionControl::new(NewReno::default(), PmtudState::new()),
ClassicCongestionControl::new(NewReno::default(), PmtudState::new(IP_ADDR)),
0,
0,
&lost
));
assert!(!persistent_congestion_by_pto(
ClassicCongestionControl::new(Cubic::default(), PmtudState::new()),
ClassicCongestionControl::new(Cubic::default(), PmtudState::new(IP_ADDR)),
0,
0,
&lost
Expand All @@ -898,37 +913,37 @@ mod tests {
// sample are not considered. So 0 is ignored.
let lost = make_lost(&[0, PERSISTENT_CONG_THRESH + 1, PERSISTENT_CONG_THRESH + 2]);
assert!(!persistent_congestion_by_pto(
ClassicCongestionControl::new(NewReno::default(), PmtudState::new()),
ClassicCongestionControl::new(NewReno::default(), PmtudState::new(IP_ADDR)),
1,
1,
&lost
));
assert!(!persistent_congestion_by_pto(
ClassicCongestionControl::new(NewReno::default(), PmtudState::new()),
ClassicCongestionControl::new(NewReno::default(), PmtudState::new(IP_ADDR)),
0,
1,
&lost
));
assert!(!persistent_congestion_by_pto(
ClassicCongestionControl::new(NewReno::default(), PmtudState::new()),
ClassicCongestionControl::new(NewReno::default(), PmtudState::new(IP_ADDR)),
1,
0,
&lost
));
assert!(!persistent_congestion_by_pto(
ClassicCongestionControl::new(Cubic::default(), PmtudState::new()),
ClassicCongestionControl::new(Cubic::default(), PmtudState::new(IP_ADDR)),
1,
1,
&lost
));
assert!(!persistent_congestion_by_pto(
ClassicCongestionControl::new(Cubic::default(), PmtudState::new()),
ClassicCongestionControl::new(Cubic::default(), PmtudState::new(IP_ADDR)),
0,
1,
&lost
));
assert!(!persistent_congestion_by_pto(
ClassicCongestionControl::new(Cubic::default(), PmtudState::new()),
ClassicCongestionControl::new(Cubic::default(), PmtudState::new(IP_ADDR)),
1,
0,
&lost
Expand All @@ -949,13 +964,13 @@ mod tests {
lost[0].len(),
);
assert!(!persistent_congestion_by_pto(
ClassicCongestionControl::new(NewReno::default(), PmtudState::new()),
ClassicCongestionControl::new(NewReno::default(), PmtudState::new(IP_ADDR)),
0,
0,
&lost
));
assert!(!persistent_congestion_by_pto(
ClassicCongestionControl::new(Cubic::default(), PmtudState::new()),
ClassicCongestionControl::new(Cubic::default(), PmtudState::new(IP_ADDR)),
0,
0,
&lost
Expand All @@ -969,13 +984,13 @@ mod tests {
fn persistent_congestion_min() {
let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]);
assert!(persistent_congestion_by_pto(
ClassicCongestionControl::new(NewReno::default(), PmtudState::new()),
ClassicCongestionControl::new(NewReno::default(), PmtudState::new(IP_ADDR)),
0,
0,
&lost
));
assert!(persistent_congestion_by_pto(
ClassicCongestionControl::new(Cubic::default(), PmtudState::new()),
ClassicCongestionControl::new(Cubic::default(), PmtudState::new(IP_ADDR)),
0,
0,
&lost
Expand All @@ -988,17 +1003,17 @@ mod tests {
#[test]
fn persistent_congestion_no_prev_ack_newreno() {
let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]);
let mut cc = ClassicCongestionControl::new(NewReno::default(), PmtudState::new());
let mut cc = ClassicCongestionControl::new(NewReno::default(), PmtudState::new(IP_ADDR));
cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, &lost);
assert_eq!(cc.cwnd(), CWND_MIN);
assert_eq!(cc.cwnd(), cc.cwnd_min());
}

#[test]
fn persistent_congestion_no_prev_ack_cubic() {
let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]);
let mut cc = ClassicCongestionControl::new(Cubic::default(), PmtudState::new());
let mut cc = ClassicCongestionControl::new(Cubic::default(), PmtudState::new(IP_ADDR));
cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, &lost);
assert_eq!(cc.cwnd(), CWND_MIN);
assert_eq!(cc.cwnd(), cc.cwnd_min());
}

/// The code asserts on ordering errors.
Expand All @@ -1007,7 +1022,7 @@ mod tests {
fn persistent_congestion_unsorted_newreno() {
let lost = make_lost(&[PERSISTENT_CONG_THRESH + 2, 1]);
assert!(!persistent_congestion_by_pto(
ClassicCongestionControl::new(NewReno::default(), PmtudState::new()),
ClassicCongestionControl::new(NewReno::default(), PmtudState::new(IP_ADDR)),
0,
0,
&lost
Expand All @@ -1020,7 +1035,7 @@ mod tests {
fn persistent_congestion_unsorted_cubic() {
let lost = make_lost(&[PERSISTENT_CONG_THRESH + 2, 1]);
assert!(!persistent_congestion_by_pto(
ClassicCongestionControl::new(Cubic::default(), PmtudState::new()),
ClassicCongestionControl::new(Cubic::default(), PmtudState::new(IP_ADDR)),
0,
0,
&lost
Expand All @@ -1031,7 +1046,7 @@ mod tests {
fn app_limited_slow_start() {
const BELOW_APP_LIMIT_PKTS: usize = 5;
const ABOVE_APP_LIMIT_PKTS: usize = BELOW_APP_LIMIT_PKTS + 1;
let mut cc = ClassicCongestionControl::new(NewReno::default(), PmtudState::new());
let mut cc = ClassicCongestionControl::new(NewReno::default(), PmtudState::new(IP_ADDR));
let cwnd = cc.congestion_window;
let mut now = now();
let mut next_pn = 0;
Expand Down Expand Up @@ -1115,7 +1130,7 @@ mod tests {
const BELOW_APP_LIMIT_PKTS: usize = CWND_PKTS_CA - 2;
const ABOVE_APP_LIMIT_PKTS: usize = BELOW_APP_LIMIT_PKTS + 1;

let mut cc = ClassicCongestionControl::new(NewReno::default(), PmtudState::new());
let mut cc = ClassicCongestionControl::new(NewReno::default(), PmtudState::new(IP_ADDR));
let mut now = now();

// Change state to congestion avoidance by introducing loss.
Expand Down Expand Up @@ -1230,7 +1245,7 @@ mod tests {

#[test]
fn ecn_ce() {
let mut cc = ClassicCongestionControl::new(NewReno::default(), PmtudState::new());
let mut cc = ClassicCongestionControl::new(NewReno::default(), PmtudState::new(IP_ADDR));
let p_ce = SentPacket::new(
PacketType::Short,
1,
Expand Down
Loading

0 comments on commit 3cc307b

Please sign in to comment.