Skip to content

Commit

Permalink
Driver: Automate (re)connection logic (#81)
Browse files Browse the repository at this point in the history
This PR adds several enhancements to Driver connection logic:
* Driver (re)connection attempts now have a default timeout of around 10s.
* The driver will now attempt to retry full connection attempts using a user-provided strategy: currently, this defaults to 5 attempts under an exponential backoff strategy.
* The driver will now fire `DriverDisconnect` events at the end of any session -- this unifies (re)connection failure events with session expiry as seen in #76, which should provide users with enough detail to know *which* voice channel to reconnect to. Users still need to be careful to read the session/channel IDs to ensure that they aren't overwriting another join.

This has been tested using `cargo make ready`, and by setting low timeouts to force failures in the voice receive example (with some additional error handlers).

Closes #68.
  • Loading branch information
FelixMcFelix committed Jul 1, 2021
1 parent 8381f8c commit 210e3ae
Show file tree
Hide file tree
Showing 17 changed files with 672 additions and 90 deletions.
33 changes: 31 additions & 2 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#[cfg(feature = "driver-core")]
use super::driver::{CryptoMode, DecodeMode};
use super::driver::{retry::Retry, CryptoMode, DecodeMode};

#[cfg(feature = "gateway-core")]
use std::time::Duration;

/// Configuration for drivers and calls.
Expand Down Expand Up @@ -61,6 +60,20 @@ pub struct Config {
/// Changes to this field in a running driver will only ever increase
/// the capacity of the track store.
pub preallocated_tracks: usize,
#[cfg(feature = "driver-core")]
/// Connection retry logic for the [`Driver`].
///
/// This controls how many times the [`Driver`] should retry any connections,
/// as well as how long to wait between attempts.
///
/// [`Driver`]: crate::driver::Driver
pub driver_retry: Retry,
#[cfg(feature = "driver-core")]
/// Configures the maximum amount of time to wait for an attempted voice
/// connection to Discord.
///
/// Defaults to 10 seconds. If set to `None`, connections will never time out.
pub driver_timeout: Option<Duration>,
}

impl Default for Config {
Expand All @@ -74,6 +87,10 @@ impl Default for Config {
gateway_timeout: Some(Duration::from_secs(10)),
#[cfg(feature = "driver-core")]
preallocated_tracks: 1,
#[cfg(feature = "driver-core")]
driver_retry: Default::default(),
#[cfg(feature = "driver-core")]
driver_timeout: Some(Duration::from_secs(10)),
}
}
}
Expand All @@ -98,6 +115,18 @@ impl Config {
self
}

/// Sets this `Config`'s timeout for establishing a voice connection.
pub fn driver_timeout(mut self, driver_timeout: Option<Duration>) -> Self {
self.driver_timeout = driver_timeout;
self
}

/// Sets this `Config`'s voice connection retry configuration.
pub fn driver_retry(mut self, driver_retry: Retry) -> Self {
self.driver_retry = driver_retry;
self
}

/// This is used to prevent changes which would invalidate the current session.
pub(crate) fn make_safe(&mut self, previous: &Config, connected: bool) {
if connected {
Expand Down
14 changes: 14 additions & 0 deletions src/driver/connection/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ use crate::{
use flume::SendError;
use serde_json::Error as JsonError;
use std::{error::Error as StdError, fmt, io::Error as IoError};
#[cfg(not(feature = "tokio-02-marker"))]
use tokio::time::error::Elapsed;
#[cfg(feature = "tokio-02-marker")]
use tokio_compat::time::Elapsed;
use xsalsa20poly1305::aead::Error as CryptoError;

/// Errors encountered while connecting to a Discord voice server over the driver.
Expand Down Expand Up @@ -38,6 +42,8 @@ pub enum Error {
InterconnectFailure(Recipient),
/// Error communicating with gateway server over WebSocket.
Ws(WsError),
/// Connection attempt timed out.
TimedOut,
}

impl From<CryptoError> for Error {
Expand Down Expand Up @@ -82,6 +88,12 @@ impl From<WsError> for Error {
}
}

impl From<Elapsed> for Error {
fn from(_e: Elapsed) -> Error {
Error::TimedOut
}
}

impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "failed to connect to Discord RTP server: ")?;
Expand All @@ -99,6 +111,7 @@ impl fmt::Display for Error {
Json(e) => e.fmt(f),
InterconnectFailure(e) => write!(f, "failed to contact other task ({:?})", e),
Ws(e) => write!(f, "websocket issue ({:?}).", e),
TimedOut => write!(f, "connection attempt timed out"),
}
}
}
Expand All @@ -118,6 +131,7 @@ impl StdError for Error {
Error::Json(e) => e.source(),
Error::InterconnectFailure(_) => None,
Error::Ws(_) => None,
Error::TimedOut => None,
}
}
}
Expand Down
31 changes: 28 additions & 3 deletions src/driver/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ use error::{Error, Result};
use flume::Sender;
use std::{net::IpAddr, str::FromStr, sync::Arc};
#[cfg(not(feature = "tokio-02-marker"))]
use tokio::{net::UdpSocket, spawn};
use tokio::{net::UdpSocket, spawn, time::timeout};
#[cfg(feature = "tokio-02-marker")]
use tokio_compat::{net::UdpSocket, spawn};
use tokio_compat::{net::UdpSocket, spawn, time::timeout};
use tracing::{debug, info, instrument};
use url::Url;
use xsalsa20poly1305::{aead::NewAead, XSalsa20Poly1305 as Cipher};
Expand All @@ -42,9 +42,23 @@ pub(crate) struct Connection {

impl Connection {
/// Runs a full connection attempt, bounded by `config.driver_timeout` when one is set.
///
/// With a timeout of `None` the attempt is awaited without any time limit.
pub(crate) async fn new(
    info: ConnectionInfo,
    interconnect: &Interconnect,
    config: &Config,
    idx: usize,
) -> Result<Connection> {
    // Futures are lazy: building the attempt up-front does no work until awaited.
    let attempt = Connection::new_inner(info, interconnect, config, idx);

    match config.driver_timeout {
        // `?` converts an elapsed timer into this module's error type.
        Some(limit) => timeout(limit, attempt).await?,
        None => attempt.await,
    }
}

pub(crate) async fn new_inner(
mut info: ConnectionInfo,
interconnect: &Interconnect,
config: &Config,
idx: usize,
) -> Result<Connection> {
let url = generate_url(&mut info.endpoint)?;

Expand Down Expand Up @@ -207,6 +221,8 @@ impl Connection {
client,
ssrc,
hello.heartbeat_interval,
idx,
info.clone(),
));

spawn(udp_rx::runner(
Expand All @@ -226,7 +242,16 @@ impl Connection {
}

#[instrument(skip(self))]
pub async fn reconnect(&mut self) -> Result<()> {
/// Runs a reconnection attempt, bounded by `config.driver_timeout` when one is set.
pub async fn reconnect(&mut self, config: &Config) -> Result<()> {
    // Futures are lazy: building the attempt up-front does no work until awaited.
    let attempt = self.reconnect_inner();

    match config.driver_timeout {
        // `?` converts an elapsed timer into this module's error type.
        Some(limit) => timeout(limit, attempt).await?,
        None => attempt.await,
    }
}

#[instrument(skip(self))]
pub async fn reconnect_inner(&mut self) -> Result<()> {
let url = generate_url(&mut self.info.endpoint)?;

// Thread may have died, we want to send to prompt a clean exit
Expand Down
1 change: 1 addition & 0 deletions src/driver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub mod bench_internals;
pub(crate) mod connection;
mod crypto;
mod decode_mode;
pub mod retry;
pub(crate) mod tasks;

use connection::error::{Error, Result};
Expand Down
49 changes: 49 additions & 0 deletions src/driver/retry/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//! Configuration for connection retries.

mod strategy;

pub use self::strategy::*;

use std::time::Duration;

/// Configuration to be used for retrying driver connection attempts.
///
/// Pairs a waiting [`Strategy`] with an optional cap on the number of retries.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Retry {
/// Strategy used to determine how long to wait between retry attempts.
///
/// *Defaults to an [`ExponentialBackoff`] from 0.25s
/// to 10s, with a jitter of `0.1`.*
///
/// [`ExponentialBackoff`]: Strategy::Backoff
pub strategy: Strategy,
/// The maximum number of retries to attempt.
///
/// `None` will attempt an infinite number of retries,
/// while `Some(0)` will attempt to connect *once* (no retries).
///
/// *Defaults to `Some(5)`.*
pub retry_limit: Option<usize>,
}

impl Default for Retry {
fn default() -> Self {
Self {
strategy: Strategy::Backoff(Default::default()),
retry_limit: Some(5),
}
}
}

impl Retry {
    /// Decides whether another retry is allowed and, if so, how long to wait first.
    ///
    /// Returns `None` once `attempts` has reached `retry_limit`; with no limit
    /// configured a wait time is always produced. `last_wait` is passed through
    /// to the configured strategy.
    pub(crate) fn retry_in(
        &self,
        last_wait: Option<Duration>,
        attempts: usize,
    ) -> Option<Duration> {
        match self.retry_limit {
            // Retry budget is spent: signal the caller to give up.
            Some(limit) if attempts >= limit => None,
            _ => Some(self.strategy.retry_in(last_wait)),
        }
    }
}
84 changes: 84 additions & 0 deletions src/driver/retry/strategy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use rand::random;
use std::time::Duration;

/// Logic used to determine how long to wait between retry attempts.
///
/// This enum is marked `#[non_exhaustive]`, so matches written outside this
/// crate need a wildcard arm.
#[derive(Clone, Copy, Debug, PartialEq)]
#[non_exhaustive]
pub enum Strategy {
/// The driver will wait for the same amount of time between each retry.
Every(Duration),
/// Exponential backoff waiting strategy, where the duration between
/// attempts (approximately) doubles each time.
///
/// See [`ExponentialBackoff`] for the configurable bounds and jitter.
Backoff(ExponentialBackoff),
}

impl Strategy {
    /// Returns the delay to apply before the next retry.
    ///
    /// `last_wait` is forwarded to the backoff strategy; the fixed-interval
    /// strategy ignores it.
    pub(crate) fn retry_in(&self, last_wait: Option<Duration>) -> Duration {
        match *self {
            Self::Backoff(ref backoff) => backoff.retry_in(last_wait),
            Self::Every(interval) => interval,
        }
    }
}

/// Exponential backoff waiting strategy.
///
/// Each attempt waits for twice the last delay plus/minus a
/// random jitter, clamped to a min and max value.
///
/// When there is no previous delay, the wait is derived from `min` instead.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ExponentialBackoff {
/// Minimum amount of time to wait between retries.
///
/// *Defaults to 0.25s.*
pub min: Duration,
/// Maximum amount of time to wait between retries.
///
/// This will be clamped to `>=` min: if `max` is smaller than `min`,
/// `min` is used as the upper bound instead.
///
/// *Defaults to 10s.*
pub max: Duration,
/// Amount of uniform random jitter to apply to generated wait times.
/// I.e., 0.1 will add +/-10% to generated intervals.
///
/// This is restricted to within +/-100%.
///
/// *Defaults to `0.1`.*
pub jitter: f32,
}

impl Default for ExponentialBackoff {
    /// Waits between 250ms and 10s, with a jitter of `0.1`.
    fn default() -> Self {
        let min = Duration::from_millis(250);
        let max = Duration::from_secs(10);
        let jitter = 0.1;

        Self { min, max, jitter }
    }
}

impl ExponentialBackoff {
    /// Computes how long to wait before the next connection attempt.
    ///
    /// `last_wait` is the delay used before the previous retry, or `None` if
    /// there was none. The base delay is twice `last_wait` (or `min` when
    /// there is no previous delay), scaled by a uniformly random factor in
    /// `1.0 +/- jitter`, and finally clamped into `[min, max]`.
    pub(crate) fn retry_in(&self, last_wait: Option<Duration>) -> Duration {
        let attempt = last_wait.map_or(self.min, |t| 2 * t);

        // `random::<f32>()` is uniform over `[0, 1)`, so it must be centred on
        // zero (`- 0.5`) before scaling to give a factor in `1.0 +/- jitter`.
        // Offsetting by a full `1.0` would only ever lengthen the wait, by
        // anywhere up to `2 * jitter`, contradicting the documented `+/-`.
        let perturb = (1.0 - (self.jitter * 2.0 * (random::<f32>() - 0.5)))
            .max(0.0)
            .min(2.0);
        let target_time = attempt.mul_f32(perturb);

        // A misconfigured `max < min` collapses the allowed range onto `min`.
        let safe_max = self.max.max(self.min);

        // Cap at the upper bound first, then raise to the floor.
        target_time.min(safe_max).max(self.min)
    }
}
4 changes: 3 additions & 1 deletion src/driver/tasks/message/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use crate::{
driver::{connection::error::Error, Bitrate, Config},
events::EventData,
events::{context_data::DisconnectReason, EventData},
tracks::Track,
ConnectionInfo,
};
Expand All @@ -12,6 +12,8 @@ use flume::Sender;
#[derive(Debug)]
pub enum CoreMessage {
ConnectWithResult(ConnectionInfo, Sender<Result<(), Error>>),
RetryConnect(usize),
SignalWsClosure(usize, ConnectionInfo, Option<DisconnectReason>),
Disconnect,
SetTrack(Option<Track>),
AddTrack(Track),
Expand Down
Loading

0 comments on commit 210e3ae

Please sign in to comment.