Skip to content

Commit

Permalink
Driver: Support tokio-websockets (#226)
Browse files Browse the repository at this point in the history
* Driver: Support `tokio-websockets`

* Fix bad feature flag

* Fix CI & examples features

* Use tungstenite in twilight example

* Error if none or both ws features are enabled

* Match `twilight-gateway` features
  • Loading branch information
decahedron1 committed Feb 28, 2024
1 parent 80d9627 commit 9bb3a68
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 13 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ jobs:
os: windows-latest
dont-test: true
- name: driver only
features: driver rustls
features: driver tungstenite rustls
dont-test: true
- name: gateway only
features: gateway serenity rustls
features: gateway serenity tungstenite rustls
dont-test: true
- name: simd json
features: simd-json serenity rustls driver gateway serenity?/simd_json
features: simd-json serenity tungstenite rustls driver gateway serenity?/simd_json
rustflags: -C target-cpu=native
dont-test: true
steps:
Expand Down
10 changes: 8 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ symphonia = { default_features = false, optional = true, version = "0.5.2" }
symphonia-core = { optional = true, version = "0.5.2" }
tokio = { default-features = false, optional = true, version = "1.0" }
tokio-tungstenite = { optional = true, version = "0.21" }
tokio-websockets = { optional = true, version = "0.5", features = ["client", "fastrand", "sha1_smol", "simd"] }
tokio-util = { features = ["io"], optional = true, version = "0.7" }
tracing = { version = "0.1", features = ["log"] }
tracing-futures = "0.2"
Expand All @@ -67,6 +68,7 @@ default = [
"gateway",
"rustls",
"serenity",
"tungstenite"
]
gateway = [
"dep:async-trait",
Expand Down Expand Up @@ -98,7 +100,6 @@ driver = [
"dep:symphonia",
"dep:symphonia-core",
"dep:tokio",
"dep:tokio-tungstenite",
"dep:tokio-util",
"dep:url",
"dep:uuid",
Expand All @@ -115,22 +116,27 @@ rustls = [
"reqwest?/rustls-tls",
"serenity?/rustls_backend",
"tokio-tungstenite?/rustls-tls-webpki-roots",
"tokio-websockets?/ring",
"tokio-websockets?/rustls-native-roots",
"twilight-gateway?/rustls-native-roots",
]
native = [
"reqwest?/native-tls",
"serenity?/native_tls_backend",
"tokio-tungstenite?/native-tls",
"tokio-websockets?/native-tls",
"twilight-gateway?/native",
]
tungstenite = ["dep:tokio-tungstenite"]
tws = ["dep:tokio-websockets"]
twilight = ["dep:twilight-gateway","dep:twilight-model"]

# Behaviour altering features.
builtin-queue = []
receive = ["dep:bytes", "discortp?/demux", "discortp?/rtcp"]

# Used for docgen/testing/benchmarking.
full-doc = ["default", "twilight", "builtin-queue", "receive"]
full-doc = ["default", "tungstenite", "twilight", "builtin-queue", "receive"]
internals = ["dep:byteorder"]

[lib]
Expand Down
2 changes: 1 addition & 1 deletion examples/twilight/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ edition = "2021"
[dependencies]
futures = "0.3"
reqwest = { workspace = true }
songbird = { workspace = true, features = ["driver", "gateway", "twilight", "rustls"] }
songbird = { workspace = true, features = ["driver", "gateway", "twilight", "rustls", "tungstenite"] }
symphonia = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
Expand Down
12 changes: 12 additions & 0 deletions src/driver/tasks/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use tokio::{
select,
time::{sleep_until, Instant},
};
#[cfg(feature = "tungstenite")]
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
use tracing::{debug, info, instrument, trace, warn};

Expand Down Expand Up @@ -241,6 +242,7 @@ pub(crate) async fn runner(mut interconnect: Interconnect, mut aux: AuxNetwork)

fn ws_error_is_not_final(err: &WsError) -> bool {
match err {
#[cfg(feature = "tungstenite")]
WsError::WsClosed(Some(frame)) => match frame.code {
CloseCode::Library(l) =>
if let Some(code) = VoiceCloseCode::from_u16(l) {
Expand All @@ -250,6 +252,16 @@ fn ws_error_is_not_final(err: &WsError) -> bool {
},
_ => true,
},
#[cfg(feature = "tws")]
WsError::WsClosed(Some(code)) => match (*code).into() {
code @ 4000..=4999_u16 =>
if let Some(code) = VoiceCloseCode::from_u16(code) {
code.should_resume()
} else {
true
},
_ => true,
},
e => {
debug!("Error sending/receiving ws {:?}.", e);
true
Expand Down
7 changes: 7 additions & 0 deletions src/events/context/data/disconnect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
model::{CloseCode as VoiceCloseCode, FromPrimitive},
ws::Error as WsError,
};
#[cfg(feature = "tungstenite")]
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;

/// Voice connection details gathered at termination or failure.
Expand Down Expand Up @@ -104,10 +105,16 @@ impl From<&ConnectionError> for DisconnectReason {
impl From<&WsError> for DisconnectReason {
fn from(e: &WsError) -> Self {
Self::WsClosed(match e {
#[cfg(feature = "tungstenite")]
WsError::WsClosed(Some(frame)) => match frame.code {
CloseCode::Library(l) => VoiceCloseCode::from_u16(l),
_ => None,
},
#[cfg(feature = "tws")]
WsError::WsClosed(Some(code)) => match (*code).into() {
code @ 4000..=4999_u16 => VoiceCloseCode::from_u16(code),
_ => None,
},
_ => None,
})
}
Expand Down
82 changes: 75 additions & 7 deletions src/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use tokio::{
net::TcpStream,
time::{timeout, Duration},
};
#[cfg(feature = "tungstenite")]
use tokio_tungstenite::{
tungstenite::{
error::Error as TungsteniteError,
Expand All @@ -14,14 +15,30 @@ use tokio_tungstenite::{
MaybeTlsStream,
WebSocketStream,
};
#[cfg(feature = "tws")]
use tokio_websockets::{
CloseCode,
Error as TwsError,
Limits,
MaybeTlsStream,
Message,
WebSocketStream,
};
use tracing::{debug, instrument};
use url::Url;

#[cfg(any(
all(feature = "tws", feature = "tungstenite"),
all(not(feature = "tws"), not(feature = "tungstenite"))
))]
compile_error!("specify one of `features = [\"tungstenite\"]` (recommended w/ serenity) or `features = [\"tws\"]` (recommended w/ twilight)");

pub struct WsStream(WebSocketStream<MaybeTlsStream<TcpStream>>);

impl WsStream {
#[instrument]
pub(crate) async fn connect(url: Url) -> Result<Self> {
#[cfg(feature = "tungstenite")]
let (stream, _) = tokio_tungstenite::connect_async_with_config::<Url>(
url,
Some(Config {
Expand All @@ -32,6 +49,13 @@ impl WsStream {
true,
)
.await?;
#[cfg(feature = "tws")]
let (stream, _) = tokio_websockets::ClientBuilder::new()
.limits(Limits::unlimited())
.uri(url.as_str())
.unwrap() // Any valid URL is a valid URI.
.connect()
.await?;

Ok(Self(stream))
}
Expand All @@ -53,11 +77,12 @@ impl WsStream {
}

pub(crate) async fn send_json(&mut self, value: &Event) -> Result<()> {
Ok(crate::json::to_string(value)
.map(Message::Text)
.map_err(Error::from)
.map(|m| self.0.send(m))?
.await?)
let res = crate::json::to_string(value);
#[cfg(feature = "tungstenite")]
let res = res.map(Message::Text);
#[cfg(feature = "tws")]
let res = res.map(Message::text);
Ok(res.map_err(Error::from).map(|m| self.0.send(m))?.await?)
}
}

Expand All @@ -71,9 +96,15 @@ pub enum Error {
/// As a result, only text messages are expected.
UnexpectedBinaryMessage(Vec<u8>),

#[cfg(feature = "tungstenite")]
Ws(TungsteniteError),
#[cfg(feature = "tws")]
Ws(TwsError),

#[cfg(feature = "tungstenite")]
WsClosed(Option<CloseFrame<'static>>),
#[cfg(feature = "tws")]
WsClosed(Option<CloseCode>),
}

impl From<JsonError> for Error {
Expand All @@ -82,16 +113,25 @@ impl From<JsonError> for Error {
}
}

#[cfg(feature = "tungstenite")]
impl From<TungsteniteError> for Error {
fn from(e: TungsteniteError) -> Error {
Error::Ws(e)
}
}

#[cfg(feature = "tws")]
impl From<TwsError> for Error {
fn from(e: TwsError) -> Self {
Error::Ws(e)
}
}

#[inline]
#[allow(unused_unsafe)]
pub(crate) fn convert_ws_message(message: Option<Message>) -> Result<Option<Event>> {
Ok(match message {
#[cfg(feature = "tungstenite")]
return Ok(match message {
// SAFETY:
// simd-json::serde::from_str may leave an &mut str in a non-UTF state on failure.
// The below is safe as we have taken ownership of the inner `String`, and if
Expand All @@ -112,5 +152,33 @@ pub(crate) fn convert_ws_message(message: Option<Message>) -> Result<Option<Even
},
// Ping/Pong message behaviour is internally handled by tungstenite.
_ => None,
})
});
#[cfg(feature = "tws")]
return Ok(if let Some(message) = message {
if message.is_text() {
let mut payload = message.as_text().unwrap().to_owned();
// SAFETY:
// simd-json::serde::from_str may leave an &mut str in a non-UTF state on failure.
// The below is safe as we have created an owned copy of the payload `&str`, and if
// failure occurs we forcibly re-validate its contents before logging.
(unsafe { crate::json::from_str(payload.as_mut_str()) })
.map_err(|e| {
let safe_payload = String::from_utf8_lossy(payload.as_bytes());
debug!("Unexpected JSON: {e}. Payload: {safe_payload}");
e
})
.ok()
} else if message.is_binary() {
return Err(Error::UnexpectedBinaryMessage(
message.into_payload().to_vec(),
));
} else if message.is_close() {
return Err(Error::WsClosed(message.as_close().map(|(c, _)| c)));
} else {
// ping/pong; will also be internally handled by tokio-websockets
None
}
} else {
None
});
}

0 comments on commit 9bb3a68

Please sign in to comment.