Skip to content

Commit

Permalink
Add retaining server list order option
Browse files Browse the repository at this point in the history
Signed-off-by: Tomasz Pietrek <tomasz@nats.io>
  • Loading branch information
Jarema committed Mar 22, 2023
1 parent ea5a07d commit fa1baef
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 14 deletions.
1 change: 1 addition & 0 deletions async-nats/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ thiserror = "1.0"
base64 = "0.13"
tokio-retry = "0.3"
ring = "0.16"
rand = "0.8"

[dev-dependencies]
criterion = { version = "0.3", features = ["async_tokio"]}
Expand Down
34 changes: 21 additions & 13 deletions async-nats/src/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use crate::ToServerAddrs;
use crate::LANG;
use crate::VERSION;
use bytes::BytesMut;
use rand::thread_rng;
use std::cmp;
use std::collections::HashMap;
use std::io;
use std::path::PathBuf;
use std::sync::Arc;
Expand All @@ -53,13 +53,15 @@ pub(crate) struct ConnectorOptions {
pub(crate) connection_timeout: Duration,
pub(crate) name: Option<String>,
pub(crate) ignore_discovered_servers: bool,
pub(crate) retain_servers_order: bool,
}

/// Maintains a list of servers and establishes connections.
pub(crate) struct Connector {
/// A map of servers and number of connect attempts.
servers: HashMap<ServerAddr, usize>,
servers: Vec<(ServerAddr, usize)>,
options: ConnectorOptions,
attempts: usize,
pub(crate) events_tx: tokio::sync::mpsc::Sender<Event>,
pub(crate) state_tx: tokio::sync::watch::Sender<State>,
}
Expand All @@ -74,6 +76,7 @@ impl Connector {
let servers = addrs.to_server_addrs()?.map(|addr| (addr, 0)).collect();

Ok(Connector {
attempts: 0,
servers,
options,
events_tx,
Expand All @@ -96,21 +99,26 @@ impl Connector {
}

pub(crate) async fn try_connect(&mut self) -> Result<(ServerInfo, Connection), ConnectError> {
use rand::seq::SliceRandom;
let mut error = None;

let server_addrs: Vec<ServerAddr> = self.servers.keys().cloned().collect();
for server_addr in server_addrs {
let server_attempts = self.servers.get_mut(&server_addr).unwrap();
let duration = if *server_attempts == 0 {
let mut servers = self.servers.clone();
if !self.options.retain_servers_order {
servers.shuffle(&mut thread_rng());
// sort_by is stable, meaning it will retain the order for equal elements.
servers.sort_by(|a, b| a.1.cmp(&b.1));
}

for (server_addr, _) in servers {
let duration = if self.attempts == 0 {
Duration::from_millis(0)
} else {
let exp: u32 = (*server_attempts - 1).try_into().unwrap_or(std::u32::MAX);
let exp: u32 = (self.attempts - 1).try_into().unwrap_or(std::u32::MAX);
let max = Duration::from_secs(4);

cmp::min(Duration::from_millis(2_u64.saturating_pow(exp)), max)
};

*server_attempts += 1;
self.attempts += 1;
sleep(duration).await;

let socket_addrs = server_addr
Expand All @@ -130,13 +138,12 @@ impl Connector {
err,
)
})?;
self.servers.entry(server_addr).or_insert(0);
if !self.servers.iter().any(|(addr, _)| addr == &server_addr) {
self.servers.push((server_addr, 0));
}
}
}

let server_attempts = self.servers.get_mut(&server_addr).unwrap();
*server_attempts = 0;

let tls_required = self.options.tls_required || server_addr.tls_required();
let mut connect_info = ConnectInfo {
tls_required,
Expand Down Expand Up @@ -227,6 +234,7 @@ impl Connector {
}
},
Some(_) => {
self.attempts = 0;
self.events_tx.send(Event::Connected).await.ok();
self.state_tx.send(State::Connected).ok();
return Ok((server_info, connection));
Expand Down
1 change: 1 addition & 0 deletions async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,7 @@ pub async fn connect_with_options<A: ToServerAddrs>(
connection_timeout: options.connection_timeout,
name: options.name,
ignore_discovered_servers: options.ignore_discovered_servers,
retain_servers_order: options.retain_servers_order,
},
events_tx,
state_tx,
Expand Down
10 changes: 10 additions & 0 deletions async-nats/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pub struct ConnectOptions {
pub(crate) request_timeout: Option<Duration>,
pub(crate) retry_on_initial_connect: bool,
pub(crate) ignore_discovered_servers: bool,
pub(crate) retain_servers_order: bool,
}

impl fmt::Debug for ConnectOptions {
Expand Down Expand Up @@ -107,6 +108,7 @@ impl Default for ConnectOptions {
request_timeout: Some(Duration::from_secs(10)),
retry_on_initial_connect: false,
ignore_discovered_servers: false,
retain_servers_order: false,
}
}
}
Expand Down Expand Up @@ -575,6 +577,14 @@ impl ConnectOptions {
self.ignore_discovered_servers = true;
self
}

/// By default, client will pick random server to which it will try connect to.
/// This option disables that feature, forcing it to always respect the order
/// in which server addresses were passed.
pub fn retain_servers_order(mut self) -> ConnectOptions {
self.retain_servers_order = true;
self
}
}
type AsyncCallbackArg1<A, T> =
Box<dyn Fn(A) -> Pin<Box<dyn Future<Output = T> + Send + Sync + 'static>> + Send + Sync>;
Expand Down
34 changes: 33 additions & 1 deletion async-nats/tests/client_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
mod client {
use async_nats::connection::State;
use async_nats::header::HeaderValue;
use async_nats::{ConnectErrorKind, ConnectOptions, Event, Request, RequestErrorKind};
use async_nats::{
ConnectErrorKind, ConnectOptions, Event, Request, RequestErrorKind, ServerAddr,
};
use bytes::Bytes;
use futures::future::join_all;
use futures::stream::StreamExt;
Expand Down Expand Up @@ -731,4 +733,34 @@ mod client {
let _server = nats_server::run_server_with_port("", Some("7777"));
sub.next().await.unwrap();
}

#[tokio::test]
async fn retained_servers_order() {
let mut servers = vec![
nats_server::run_basic_server(),
nats_server::run_basic_server(),
];
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
let client = ConnectOptions::with_user_and_password("js".into(), "js".into())
.event_callback(move |event| {
let tx = tx.clone();
async move {
if let Event::Disconnected = event {
tx.send(()).unwrap();
}
}
})
.retain_servers_order()
.connect(
servers
.iter()
.map(|s| s.client_url().parse::<ServerAddr>().unwrap())
.collect::<Vec<ServerAddr>>(),
)
.await
.unwrap();

drop(servers.remove(0));
rx.recv().await;
}
}

0 comments on commit fa1baef

Please sign in to comment.