diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2c4bd2e12..838a92c2f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -48,13 +48,13 @@ jobs: os: windows-latest dont-test: true - name: driver only - features: driver rustls + features: driver tungstenite rustls dont-test: true - name: gateway only - features: gateway serenity rustls + features: gateway serenity tungstenite rustls dont-test: true - name: simd json - features: simd-json serenity rustls driver gateway serenity?/simd_json + features: simd-json serenity tungstenite rustls driver gateway serenity?/simd_json rustflags: -C target-cpu=native dont-test: true steps: diff --git a/Cargo.toml b/Cargo.toml index a626c7ea6..b6b6138f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,7 @@ symphonia = { default_features = false, optional = true, version = "0.5.2" } symphonia-core = { optional = true, version = "0.5.2" } tokio = { default-features = false, optional = true, version = "1.0" } tokio-tungstenite = { optional = true, version = "0.21" } +tokio-websockets = { optional = true, version = "0.5", features = ["client", "fastrand", "sha1_smol", "simd"] } tokio-util = { features = ["io"], optional = true, version = "0.7" } tracing = { version = "0.1", features = ["log"] } tracing-futures = "0.2" @@ -67,6 +68,7 @@ default = [ "gateway", "rustls", "serenity", + "tungstenite" ] gateway = [ "dep:async-trait", @@ -98,7 +100,6 @@ driver = [ "dep:symphonia", "dep:symphonia-core", "dep:tokio", - "dep:tokio-tungstenite", "dep:tokio-util", "dep:url", "dep:uuid", @@ -115,14 +116,19 @@ rustls = [ "reqwest?/rustls-tls", "serenity?/rustls_backend", "tokio-tungstenite?/rustls-tls-webpki-roots", + "tokio-websockets?/ring", + "tokio-websockets?/rustls-native-roots", "twilight-gateway?/rustls-native-roots", ] native = [ "reqwest?/native-tls", "serenity?/native_tls_backend", "tokio-tungstenite?/native-tls", + "tokio-websockets?/native-tls", "twilight-gateway?/native", ] +tungstenite = ["dep:tokio-tungstenite"] +tws = ["dep:tokio-websockets"] twilight = ["dep:twilight-gateway","dep:twilight-model"] # Behaviour altering features. @@ -130,7 +136,7 @@ builtin-queue = [] receive = ["dep:bytes", "discortp?/demux", "discortp?/rtcp"] # Used for docgen/testing/benchmarking. -full-doc = ["default", "twilight", "builtin-queue", "receive"] +full-doc = ["default", "tungstenite", "twilight", "builtin-queue", "receive"] internals = ["dep:byteorder"] [lib] diff --git a/examples/twilight/Cargo.toml b/examples/twilight/Cargo.toml index 80d7d936e..6f17af0ba 100644 --- a/examples/twilight/Cargo.toml +++ b/examples/twilight/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" [dependencies] futures = "0.3" reqwest = { workspace = true } -songbird = { workspace = true, features = ["driver", "gateway", "twilight", "rustls"] } +songbird = { workspace = true, features = ["driver", "gateway", "twilight", "rustls", "tungstenite"] } symphonia = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } diff --git a/src/driver/tasks/ws.rs b/src/driver/tasks/ws.rs index 3d7697eee..3c8d91904 100644 --- a/src/driver/tasks/ws.rs +++ b/src/driver/tasks/ws.rs @@ -20,6 +20,7 @@ use tokio::{ select, time::{sleep_until, Instant}, }; +#[cfg(feature = "tungstenite")] use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; use tracing::{debug, info, instrument, trace, warn}; @@ -241,6 +242,7 @@ pub(crate) async fn runner(mut interconnect: Interconnect, mut aux: AuxNetwork) fn ws_error_is_not_final(err: &WsError) -> bool { match err { + #[cfg(feature = "tungstenite")] WsError::WsClosed(Some(frame)) => match frame.code { CloseCode::Library(l) => if let Some(code) = VoiceCloseCode::from_u16(l) { @@ -250,6 +252,16 @@ fn ws_error_is_not_final(err: &WsError) -> bool { }, _ => true, }, + #[cfg(feature = "tws")] + WsError::WsClosed(Some(code)) => match (*code).into() { + code @ 4000..=4999_u16 => + if let Some(code) = VoiceCloseCode::from_u16(code) { + code.should_resume() + } else { + true + }, + _ => true, + }, e => { debug!("Error sending/receiving ws {:?}.", e); true diff --git a/src/events/context/data/disconnect.rs b/src/events/context/data/disconnect.rs index 4275849b1..16425a7da 100644 --- a/src/events/context/data/disconnect.rs +++ b/src/events/context/data/disconnect.rs @@ -4,6 +4,7 @@ use crate::{ model::{CloseCode as VoiceCloseCode, FromPrimitive}, ws::Error as WsError, }; +#[cfg(feature = "tungstenite")] use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; /// Voice connection details gathered at termination or failure. @@ -104,10 +105,16 @@ impl From<&ConnectionError> for DisconnectReason { impl From<&WsError> for DisconnectReason { fn from(e: &WsError) -> Self { Self::WsClosed(match e { + #[cfg(feature = "tungstenite")] WsError::WsClosed(Some(frame)) => match frame.code { CloseCode::Library(l) => VoiceCloseCode::from_u16(l), _ => None, }, + #[cfg(feature = "tws")] + WsError::WsClosed(Some(code)) => match (*code).into() { + code @ 4000..=4999_u16 => VoiceCloseCode::from_u16(code), + _ => None, + }, _ => None, }) } diff --git a/src/ws.rs b/src/ws.rs index 939ca4520..bda6230f2 100644 --- a/src/ws.rs +++ b/src/ws.rs @@ -5,6 +5,7 @@ use tokio::{ net::TcpStream, time::{timeout, Duration}, }; +#[cfg(feature = "tungstenite")] use tokio_tungstenite::{ tungstenite::{ error::Error as TungsteniteError, @@ -14,14 +15,30 @@ use tokio_tungstenite::{ MaybeTlsStream, WebSocketStream, }; +#[cfg(feature = "tws")] +use tokio_websockets::{ + CloseCode, + Error as TwsError, + Limits, + MaybeTlsStream, + Message, + WebSocketStream, +}; use tracing::{debug, instrument}; use url::Url; +#[cfg(any( + all(feature = "tws", feature = "tungstenite"), + all(not(feature = "tws"), not(feature = "tungstenite")) +))] +compile_error!("specify one of `features = [\"tungstenite\"]` (recommended w/ serenity) or `features = [\"tws\"]` (recommended w/ twilight)"); + pub struct WsStream(WebSocketStream>); impl WsStream { #[instrument] pub(crate) async fn connect(url: Url) -> Result { + #[cfg(feature = "tungstenite")] let (stream, _) = tokio_tungstenite::connect_async_with_config::( url, Some(Config { @@ -32,6 +49,13 @@ impl WsStream { true, ) .await?; + #[cfg(feature = "tws")] + let (stream, _) = tokio_websockets::ClientBuilder::new() + .limits(Limits::unlimited()) + .uri(url.as_str()) + .unwrap() // Any valid URL is a valid URI. + .connect() + .await?; Ok(Self(stream)) } @@ -53,11 +77,12 @@ impl WsStream { } pub(crate) async fn send_json(&mut self, value: &Event) -> Result<()> { - Ok(crate::json::to_string(value) - .map(Message::Text) - .map_err(Error::from) - .map(|m| self.0.send(m))? - .await?) + let res = crate::json::to_string(value); + #[cfg(feature = "tungstenite")] + let res = res.map(Message::Text); + #[cfg(feature = "tws")] + let res = res.map(Message::text); + Ok(res.map_err(Error::from).map(|m| self.0.send(m))?.await?) } } @@ -71,9 +96,15 @@ pub enum Error { /// As a result, only text messages are expected. UnexpectedBinaryMessage(Vec), + #[cfg(feature = "tungstenite")] Ws(TungsteniteError), + #[cfg(feature = "tws")] + Ws(TwsError), + #[cfg(feature = "tungstenite")] WsClosed(Option>), + #[cfg(feature = "tws")] + WsClosed(Option), } impl From for Error { @@ -82,16 +113,25 @@ impl From for Error { } } +#[cfg(feature = "tungstenite")] impl From for Error { fn from(e: TungsteniteError) -> Error { Error::Ws(e) } } +#[cfg(feature = "tws")] +impl From for Error { + fn from(e: TwsError) -> Self { + Error::Ws(e) + } +} + #[inline] #[allow(unused_unsafe)] pub(crate) fn convert_ws_message(message: Option) -> Result> { - Ok(match message { + #[cfg(feature = "tungstenite")] + return Ok(match message { // SAFETY: // simd-json::serde::from_str may leave an &mut str in a non-UTF state on failure. // The below is safe as we have taken ownership of the inner `String`, and if @@ -112,5 +152,33 @@ pub(crate) fn convert_ws_message(message: Option) -> Result None, - }) + }); + #[cfg(feature = "tws")] + return Ok(if let Some(message) = message { + if message.is_text() { + let mut payload = message.as_text().unwrap().to_owned(); + // SAFETY: + // simd-json::serde::from_str may leave an &mut str in a non-UTF state on failure. + // The below is safe as we have created an owned copy of the payload `&str`, and if + // failure occurs we forcibly re-validate its contents before logging. + (unsafe { crate::json::from_str(payload.as_mut_str()) }) + .map_err(|e| { + let safe_payload = String::from_utf8_lossy(payload.as_bytes()); + debug!("Unexpected JSON: {e}. Payload: {safe_payload}"); + e + }) + .ok() + } else if message.is_binary() { + return Err(Error::UnexpectedBinaryMessage( + message.into_payload().to_vec(), + )); + } else if message.is_close() { + return Err(Error::WsClosed(message.as_close().map(|(c, _)| c))); + } else { + // ping/pong; will also be internally handled by tokio-websockets + None + } + } else { + None + }); }