diff --git a/Cargo.toml b/Cargo.toml index f6a47f550..36a29b7bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,10 @@ serde_json = "1" tracing = { version = "0.1", features = ["log"] } tracing-futures = "0.2" +[dependencies.once_cell] +version = "1" +optional = true + [dependencies.async-trait] optional = true version = "0.1" @@ -168,6 +172,7 @@ default = [ ] gateway = [ "dashmap", + "once_cell", "flume", "parking_lot", "tokio/sync", diff --git a/src/manager.rs b/src/manager.rs index 02c6802ed..ef811451d 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -13,6 +13,7 @@ use async_trait::async_trait; use dashmap::DashMap; #[cfg(feature = "serenity")] use futures::channel::mpsc::UnboundedSender as Sender; +use once_cell::sync::OnceCell; use parking_lot::RwLock as PRwLock; #[cfg(feature = "serenity")] use serenity::{ @@ -45,7 +46,7 @@ struct ClientData { /// [`Call`]: Call #[derive(Debug)] pub struct Songbird { - client_data: PRwLock>, + client_data: OnceCell, calls: DashMap>>, sharder: Sharder, config: PRwLock>, @@ -72,7 +73,7 @@ impl Songbird { #[must_use] pub fn serenity_from_config(config: Config) -> Arc { Arc::new(Self { - client_data: PRwLock::new(None), + client_data: OnceCell::new(), calls: DashMap::new(), sharder: Sharder::Serenity(SerenitySharder::default()), config: Some(config).into(), @@ -107,10 +108,10 @@ impl Songbird { U: Into, { Self { - client_data: PRwLock::new(Some(ClientData { + client_data: OnceCell::with_value(ClientData { shard_count: cluster.config().shard_scheme().total(), user_id: user_id.into(), - })), + }), calls: DashMap::new(), sharder: Sharder::TwilightCluster(cluster), config: Some(config).into(), @@ -124,16 +125,12 @@ impl Songbird { /// /// [`::twilight`]: #method.twilight pub fn initialise_client_data>(&self, shard_count: u64, user_id: U) { - let mut client_data = self.client_data.write(); - - if client_data.is_some() { - return; - } - - *client_data = Some(ClientData { - shard_count, - user_id: user_id.into(), - }); + self.client_data + .set(ClientData { + shard_count, + user_id: user_id.into(), + }) + .ok(); } /// Retrieves a [`Call`] for the given guild, if one already exists. @@ -166,7 +163,7 @@ impl Songbird { .or_insert_with(|| { let info = self .client_data - .read() + .get() .expect("Manager has not been initialised"); let shard = shard_id(guild_id.0.get(), info.shard_count); @@ -374,8 +371,7 @@ impl Songbird { TwilightEvent::VoiceStateUpdate(v) => { if self .client_data - .read() - .as_ref() + .get() .map_or(true, |data| v.0.user_id.into_nonzero() != data.user_id.0) { return; @@ -410,7 +406,7 @@ impl VoiceGatewayManager for Songbird { "Registering Serenity shard handle {} with Songbird", shard_id ); - self.sharder.register_shard_handle(shard_id as u64, sender); + self.sharder.register_shard_handle(shard_id, sender); debug!("Registered shard handle {}.", shard_id); } @@ -419,7 +415,7 @@ impl VoiceGatewayManager for Songbird { "Deregistering Serenity shard handle {} with Songbird", shard_id ); - self.sharder.deregister_shard_handle(shard_id as u64); + self.sharder.deregister_shard_handle(shard_id); debug!("Deregistered shard handle {}.", shard_id); } @@ -435,7 +431,7 @@ impl VoiceGatewayManager for Songbird { async fn state_update(&self, guild_id: SerenityGuild, voice_state: &VoiceState) { if self .client_data - .read() + .get() .map_or(true, |data| voice_state.user_id.0 != data.user_id.0) { return; diff --git a/src/shards.rs b/src/shards.rs index a35368794..5163bf5a0 100644 --- a/src/shards.rs +++ b/src/shards.rs @@ -2,6 +2,8 @@ use crate::{error::JoinResult, id::*}; use async_trait::async_trait; +#[cfg(feature = "serenity")] +use dashmap::DashMap; use derivative::Derivative; #[cfg(feature = "serenity")] use futures::channel::mpsc::{TrySendError, UnboundedSender as Sender}; @@ -12,7 +14,7 @@ use serde_json::json; use serenity::gateway::InterMessage; use std::sync::Arc; #[cfg(feature = "serenity")] -use std::{collections::HashMap, result::Result as StdResult}; +use std::result::Result as StdResult; use tracing::{debug, error}; #[cfg(feature = "twilight")] use twilight_gateway::{Cluster, Shard as TwilightShard}; @@ -49,10 +51,11 @@ pub trait GenericSharder { impl Sharder { /// Returns a new handle to the required inner shard. + #[allow(clippy::must_use_candidate)] // get_or_insert_shard_handle has side effects pub fn get_shard(&self, shard_id: u64) -> Option { match self { #[cfg(feature = "serenity")] - Sharder::Serenity(s) => Some(Shard::Serenity(s.get_or_insert_shard_handle(shard_id))), + Sharder::Serenity(s) => Some(Shard::Serenity(s.get_or_insert_shard_handle(shard_id as u32))), #[cfg(feature = "twilight")] Sharder::TwilightCluster(t) => Some(Shard::TwilightCluster(t.clone(), shard_id)), #[cfg(feature = "twilight")] @@ -65,7 +68,7 @@ impl Sharder { #[cfg(feature = "serenity")] impl Sharder { #[allow(unreachable_patterns)] - pub(crate) fn register_shard_handle(&self, shard_id: u64, sender: Sender) { + pub(crate) fn register_shard_handle(&self, shard_id: u32, sender: Sender) { if let Sharder::Serenity(s) = self { s.register_shard_handle(shard_id, sender); } else { @@ -74,7 +77,7 @@ impl Sharder { } #[allow(unreachable_patterns)] - pub(crate) fn deregister_shard_handle(&self, shard_id: u64) { + pub(crate) fn deregister_shard_handle(&self, shard_id: u32) { if let Sharder::Serenity(s) = self { s.deregister_shard_handle(shard_id); } else { @@ -89,29 +92,22 @@ impl Sharder { /// /// This is updated and maintained by the library, and is designed to prevent /// message loss during rebalances and reconnects. -pub struct SerenitySharder(PRwLock>>); +pub struct SerenitySharder(DashMap>); #[cfg(feature = "serenity")] impl SerenitySharder { - fn get_or_insert_shard_handle(&self, shard_id: u64) -> Arc { - ({ - let map_read = self.0.read(); - map_read.get(&shard_id).cloned() - }) - .unwrap_or_else(|| { - let mut map_read = self.0.write(); - map_read.entry(shard_id).or_default().clone() - }) + fn get_or_insert_shard_handle(&self, shard_id: u32) -> Arc { + self.0.entry(shard_id).or_default().clone() } - fn register_shard_handle(&self, shard_id: u64, sender: Sender) { + fn register_shard_handle(&self, shard_id: u32, sender: Sender) { // Write locks are only used to add new entries to the map. let handle = self.get_or_insert_shard_handle(shard_id); handle.register(sender); } - fn deregister_shard_handle(&self, shard_id: u64) { + fn deregister_shard_handle(&self, shard_id: u32) { // Write locks are only used to add new entries to the map. let handle = self.get_or_insert_shard_handle(shard_id);