Skip to content

Commit

Permalink
Use OnceCell and DashMap (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
GnomedDev authored Jun 27, 2022
1 parent c069ac4 commit 8b23dcd
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 36 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ serde_json = "1"
tracing = { version = "0.1", features = ["log"] }
tracing-futures = "0.2"

[dependencies.once_cell]
version = "1"
optional = true

[dependencies.async-trait]
optional = true
version = "0.1"
Expand Down Expand Up @@ -168,6 +172,7 @@ default = [
]
gateway = [
"dashmap",
"once_cell",
"flume",
"parking_lot",
"tokio/sync",
Expand Down
36 changes: 16 additions & 20 deletions src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use async_trait::async_trait;
use dashmap::DashMap;
#[cfg(feature = "serenity")]
use futures::channel::mpsc::UnboundedSender as Sender;
use once_cell::sync::OnceCell;
use parking_lot::RwLock as PRwLock;
#[cfg(feature = "serenity")]
use serenity::{
Expand Down Expand Up @@ -45,7 +46,7 @@ struct ClientData {
/// [`Call`]: Call
#[derive(Debug)]
pub struct Songbird {
client_data: PRwLock<Option<ClientData>>,
client_data: OnceCell<ClientData>,
calls: DashMap<GuildId, Arc<Mutex<Call>>>,
sharder: Sharder,
config: PRwLock<Option<Config>>,
Expand All @@ -72,7 +73,7 @@ impl Songbird {
#[must_use]
pub fn serenity_from_config(config: Config) -> Arc<Self> {
Arc::new(Self {
client_data: PRwLock::new(None),
client_data: OnceCell::new(),
calls: DashMap::new(),
sharder: Sharder::Serenity(SerenitySharder::default()),
config: Some(config).into(),
Expand Down Expand Up @@ -107,10 +108,10 @@ impl Songbird {
U: Into<UserId>,
{
Self {
client_data: PRwLock::new(Some(ClientData {
client_data: OnceCell::with_value(ClientData {
shard_count: cluster.config().shard_scheme().total(),
user_id: user_id.into(),
})),
}),
calls: DashMap::new(),
sharder: Sharder::TwilightCluster(cluster),
config: Some(config).into(),
Expand All @@ -124,16 +125,12 @@ impl Songbird {
///
/// [`::twilight`]: #method.twilight
pub fn initialise_client_data<U: Into<UserId>>(&self, shard_count: u64, user_id: U) {
let mut client_data = self.client_data.write();

if client_data.is_some() {
return;
}

*client_data = Some(ClientData {
shard_count,
user_id: user_id.into(),
});
self.client_data
.set(ClientData {
shard_count,
user_id: user_id.into(),
})
.ok();
}

/// Retrieves a [`Call`] for the given guild, if one already exists.
Expand Down Expand Up @@ -166,7 +163,7 @@ impl Songbird {
.or_insert_with(|| {
let info = self
.client_data
.read()
.get()
.expect("Manager has not been initialised");

let shard = shard_id(guild_id.0.get(), info.shard_count);
Expand Down Expand Up @@ -374,8 +371,7 @@ impl Songbird {
TwilightEvent::VoiceStateUpdate(v) => {
if self
.client_data
.read()
.as_ref()
.get()
.map_or(true, |data| v.0.user_id.into_nonzero() != data.user_id.0)
{
return;
Expand Down Expand Up @@ -410,7 +406,7 @@ impl VoiceGatewayManager for Songbird {
"Registering Serenity shard handle {} with Songbird",
shard_id
);
self.sharder.register_shard_handle(shard_id as u64, sender);
self.sharder.register_shard_handle(shard_id, sender);
debug!("Registered shard handle {}.", shard_id);
}

Expand All @@ -419,7 +415,7 @@ impl VoiceGatewayManager for Songbird {
"Deregistering Serenity shard handle {} with Songbird",
shard_id
);
self.sharder.deregister_shard_handle(shard_id as u64);
self.sharder.deregister_shard_handle(shard_id);
debug!("Deregistered shard handle {}.", shard_id);
}

Expand All @@ -435,7 +431,7 @@ impl VoiceGatewayManager for Songbird {
async fn state_update(&self, guild_id: SerenityGuild, voice_state: &VoiceState) {
if self
.client_data
.read()
.get()
.map_or(true, |data| voice_state.user_id.0 != data.user_id.0)
{
return;
Expand Down
28 changes: 12 additions & 16 deletions src/shards.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

use crate::{error::JoinResult, id::*};
use async_trait::async_trait;
#[cfg(feature = "serenity")]
use dashmap::DashMap;
use derivative::Derivative;
#[cfg(feature = "serenity")]
use futures::channel::mpsc::{TrySendError, UnboundedSender as Sender};
Expand All @@ -12,7 +14,7 @@ use serde_json::json;
use serenity::gateway::InterMessage;
use std::sync::Arc;
#[cfg(feature = "serenity")]
use std::{collections::HashMap, result::Result as StdResult};
use std::result::Result as StdResult;
use tracing::{debug, error};
#[cfg(feature = "twilight")]
use twilight_gateway::{Cluster, Shard as TwilightShard};
Expand Down Expand Up @@ -49,10 +51,11 @@ pub trait GenericSharder {

impl Sharder {
/// Returns a new handle to the required inner shard.
#[allow(clippy::must_use_candidate)] // get_or_insert_shard_handle has side effects
pub fn get_shard(&self, shard_id: u64) -> Option<Shard> {
match self {
#[cfg(feature = "serenity")]
Sharder::Serenity(s) => Some(Shard::Serenity(s.get_or_insert_shard_handle(shard_id))),
Sharder::Serenity(s) => Some(Shard::Serenity(s.get_or_insert_shard_handle(shard_id as u32))),
#[cfg(feature = "twilight")]
Sharder::TwilightCluster(t) => Some(Shard::TwilightCluster(t.clone(), shard_id)),
#[cfg(feature = "twilight")]
Expand All @@ -65,7 +68,7 @@ impl Sharder {
#[cfg(feature = "serenity")]
impl Sharder {
#[allow(unreachable_patterns)]
pub(crate) fn register_shard_handle(&self, shard_id: u64, sender: Sender<InterMessage>) {
pub(crate) fn register_shard_handle(&self, shard_id: u32, sender: Sender<InterMessage>) {
if let Sharder::Serenity(s) = self {
s.register_shard_handle(shard_id, sender);
} else {
Expand All @@ -74,7 +77,7 @@ impl Sharder {
}

#[allow(unreachable_patterns)]
pub(crate) fn deregister_shard_handle(&self, shard_id: u64) {
pub(crate) fn deregister_shard_handle(&self, shard_id: u32) {
if let Sharder::Serenity(s) = self {
s.deregister_shard_handle(shard_id);
} else {
Expand All @@ -89,29 +92,22 @@ impl Sharder {
///
/// This is updated and maintained by the library, and is designed to prevent
/// message loss during rebalances and reconnects.
pub struct SerenitySharder(PRwLock<HashMap<u64, Arc<SerenityShardHandle>>>);
pub struct SerenitySharder(DashMap<u32, Arc<SerenityShardHandle>>);

#[cfg(feature = "serenity")]
impl SerenitySharder {
fn get_or_insert_shard_handle(&self, shard_id: u64) -> Arc<SerenityShardHandle> {
({
let map_read = self.0.read();
map_read.get(&shard_id).cloned()
})
.unwrap_or_else(|| {
let mut map_read = self.0.write();
map_read.entry(shard_id).or_default().clone()
})
fn get_or_insert_shard_handle(&self, shard_id: u32) -> Arc<SerenityShardHandle> {
self.0.entry(shard_id).or_default().clone()
}

fn register_shard_handle(&self, shard_id: u64, sender: Sender<InterMessage>) {
fn register_shard_handle(&self, shard_id: u32, sender: Sender<InterMessage>) {
// Write locks are only used to add new entries to the map.
let handle = self.get_or_insert_shard_handle(shard_id);

handle.register(sender);
}

fn deregister_shard_handle(&self, shard_id: u64) {
fn deregister_shard_handle(&self, shard_id: u32) {
// Write locks are only used to add new entries to the map.
let handle = self.get_or_insert_shard_handle(shard_id);

Expand Down

0 comments on commit 8b23dcd

Please sign in to comment.