diff --git a/quinn-proto/src/config/mod.rs b/quinn-proto/src/config/mod.rs index 654e663f3..9cb4a45ed 100644 --- a/quinn-proto/src/config/mod.rs +++ b/quinn-proto/src/config/mod.rs @@ -17,7 +17,7 @@ use crate::{ cid_generator::{ConnectionIdGenerator, HashedConnectionIdGenerator}, crypto::{self, HandshakeTokenKey, HmacKey}, shared::ConnectionId, - Duration, RandomConnectionIdGenerator, VarInt, VarIntBoundsExceeded, + Duration, RandomConnectionIdGenerator, SystemTime, VarInt, VarIntBoundsExceeded, DEFAULT_SUPPORTED_VERSIONS, MAX_CID_SIZE, }; @@ -216,6 +216,8 @@ pub struct ServerConfig { pub(crate) max_incoming: usize, pub(crate) incoming_buffer_size: u64, pub(crate) incoming_buffer_size_total: u64, + + pub(crate) time_source: Arc, } impl ServerConfig { @@ -239,6 +241,8 @@ impl ServerConfig { max_incoming: 1 << 16, incoming_buffer_size: 10 << 20, incoming_buffer_size_total: 100 << 20, + + time_source: Arc::new(StdSystemTime), } } @@ -334,6 +338,16 @@ impl ServerConfig { self.incoming_buffer_size_total = incoming_buffer_size_total; self } + + /// Object to get current [`SystemTime`] + /// + /// This exists to allow system time to be mocked in tests, or wherever else desired. + /// + /// Defaults to [`StdSystemTime`], which simply calls [`SystemTime::now()`](SystemTime::now). + pub fn time_source(&mut self, time_source: Arc) -> &mut Self { + self.time_source = time_source; + self + } } #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))] @@ -388,6 +402,7 @@ impl fmt::Debug for ServerConfig { "incoming_buffer_size_total", &self.incoming_buffer_size_total, ) + // system_time_clock not debug .finish_non_exhaustive() } } @@ -503,3 +518,22 @@ impl From for ConfigError { Self::OutOfBounds } } + +/// Object to get current [`SystemTime`] +/// +/// This exists to allow system time to be mocked in tests, or wherever else desired. +pub trait TimeSource: Send + Sync { + /// Get [`SystemTime::now()`](SystemTime::now) or the mocked equivalent + fn now(&self) -> SystemTime; +} + +/// Default implementation of [`TimeSource`] +/// +/// Implements `now` by calling [`SystemTime::now()`](SystemTime::now). +pub struct StdSystemTime; + +impl TimeSource for StdSystemTime { + fn now(&self) -> SystemTime { + SystemTime::now() + } +} diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index 9f9820261..b0a5e269c 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -31,8 +31,8 @@ use crate::{ }, token, transport_parameters::{PreferredAddress, TransportParameters}, - Duration, Instant, ResetToken, RetryToken, Side, SystemTime, Transmit, TransportConfig, - TransportError, INITIAL_MTU, MAX_CID_SIZE, MIN_INITIAL_SIZE, RESET_TOKEN_SIZE, + Duration, Instant, ResetToken, RetryToken, Side, Transmit, TransportConfig, TransportError, + INITIAL_MTU, MAX_CID_SIZE, MIN_INITIAL_SIZE, RESET_TOKEN_SIZE, }; /// The main entry point to the library @@ -506,7 +506,8 @@ impl Endpoint { &header.token, ) { Ok(token) - if token.issued + server_config.retry_token_lifetime > SystemTime::now() => + if token.issued + server_config.retry_token_lifetime + > server_config.time_source.now() => { (Some(header.dst_cid), token.orig_dst_cid) } @@ -769,7 +770,7 @@ impl Endpoint { let token = RetryToken { orig_dst_cid: incoming.packet.header.dst_cid, - issued: SystemTime::now(), + issued: server_config.time_source.now(), } .encode( &*server_config.token_key, diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs index c12a479f3..67f13cb66 100644 --- a/quinn-proto/src/lib.rs +++ b/quinn-proto/src/lib.rs @@ -51,7 +51,7 @@ pub use rustls; mod config; pub use config::{ AckFrequencyConfig, ClientConfig, ConfigError, EndpointConfig, IdleTimeout, MtuDiscoveryConfig, - ServerConfig, TransportConfig, + ServerConfig, StdSystemTime, TimeSource, TransportConfig, }; pub mod crypto; diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs index 84a3821f6..4f6e6ea73 100644 --- a/quinn/src/lib.rs +++ b/quinn/src/lib.rs @@ -65,8 +65,9 @@ pub use proto::{ congestion, crypto, AckFrequencyConfig, ApplicationClose, Chunk, ClientConfig, ClosedStream, ConfigError, ConnectError, ConnectionClose, ConnectionError, ConnectionId, ConnectionIdGenerator, ConnectionStats, Dir, EcnCodepoint, EndpointConfig, FrameStats, - FrameType, IdleTimeout, MtuDiscoveryConfig, PathStats, ServerConfig, Side, StreamId, Transmit, - TransportConfig, TransportErrorCode, UdpStats, VarInt, VarIntBoundsExceeded, Written, + FrameType, IdleTimeout, MtuDiscoveryConfig, PathStats, ServerConfig, Side, StdSystemTime, + StreamId, TimeSource, Transmit, TransportConfig, TransportErrorCode, UdpStats, VarInt, + VarIntBoundsExceeded, Written, }; #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))] pub use rustls;