From 2fbcdd1cd01dd6b9bda3cbe145f61f145023af09 Mon Sep 17 00:00:00 2001 From: Tom Milligan Date: Fri, 23 Aug 2024 17:44:57 +0100 Subject: [PATCH] feature: Add multi-node connection pool (#189) * Implement static node list connection pool * trait for setting connection distribution. Defaults to RoundRobin. * Allow reseeding of nodes on MultiNodeConnection * Implement Sniff Nodes request --------- Co-authored-by: Stephen Leyva Co-authored-by: Sylvain Wallez --- elasticsearch/Cargo.toml | 2 +- elasticsearch/src/http/transport.rs | 389 +++++++++++++++++++++++++++- 2 files changed, 376 insertions(+), 15 deletions(-) diff --git a/elasticsearch/Cargo.toml b/elasticsearch/Cargo.toml index 337850e7..7926702a 100644 --- a/elasticsearch/Cargo.toml +++ b/elasticsearch/Cargo.toml @@ -36,6 +36,7 @@ url = "2" serde = { version = "1", features = ["derive"] } serde_json = "1" serde_with = "3" +tokio = { version = "1", default-features = false, features = ["macros", "net", "time", "rt-multi-thread"] } void = "1" [dev-dependencies] @@ -50,7 +51,6 @@ os_type = "2" regex="1" #sysinfo = "0.31" textwrap = "0.16" -tokio = { version = "1", default-features = false, features = ["macros", "net", "time", "rt-multi-thread"] } xml-rs = "0.8" [build-dependencies] diff --git a/elasticsearch/src/http/transport.rs b/elasticsearch/src/http/transport.rs index 880c1256..22100fe9 100644 --- a/elasticsearch/src/http/transport.rs +++ b/elasticsearch/src/http/transport.rs @@ -39,11 +39,16 @@ use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, write::Encode use bytes::BytesMut; use lazy_static::lazy_static; use serde::Serialize; +use serde_json::Value; use std::{ error, fmt, fmt::Debug, io::{self, Write}, - time::Duration, + sync::{ + atomic::{AtomicBool, AtomicUsize, Ordering}, + Arc, RwLock, + }, + time::{Duration, Instant}, }; use url::Url; @@ -310,7 +315,7 @@ impl TransportBuilder { let client = client_builder.build()?; Ok(Transport { client, - conn_pool: self.conn_pool, + conn_pool: Arc::new(self.conn_pool), credentials: self.credentials, send_meta: self.meta_header, }) @@ -327,7 +332,7 @@ impl Default for TransportBuilder { /// A connection to an Elasticsearch node, used to send an API request #[derive(Debug, Clone)] pub struct Connection { - url: Url, + url: Arc, } impl Connection { @@ -341,8 +346,14 @@ impl Connection { url.set_path(&format!("{}/", url.path())); } + let url = Arc::new(url); + Self { url } } + + pub fn url(&self) -> Arc { + self.url.clone() + } } /// A HTTP transport responsible for making the API requests to Elasticsearch, @@ -351,7 +362,7 @@ impl Connection { pub struct Transport { client: reqwest::Client, credentials: Option, - conn_pool: Box, + conn_pool: Arc>, send_meta: bool, } @@ -401,6 +412,35 @@ impl Transport { Ok(transport) } + /// Creates a new instance of a [Transport] configured with a + /// [MultiNodeConnectionPool] that does not refresh + pub fn static_node_list(urls: Vec<&str>) -> Result { + let urls: Vec = urls + .iter() + .map(|url| Url::parse(url)) + .collect::, _>>()?; + let conn_pool = MultiNodeConnectionPool::round_robin(urls, None); + let transport = TransportBuilder::new(conn_pool).build()?; + Ok(transport) + } + + /// Creates a new instance of a [Transport] configured with a + /// [MultiNodeConnectionPool] + /// + /// * `reseed_frequency` - frequency at which connections should be refreshed in seconds + pub fn sniffing_node_list( + urls: Vec<&str>, + reseed_frequency: Duration, + ) -> Result { + let urls: Vec = urls + .iter() + .map(|url| Url::parse(url)) + .collect::, _>>()?; + let conn_pool = MultiNodeConnectionPool::round_robin(urls, Some(reseed_frequency)); + let transport = TransportBuilder::new(conn_pool).build()?; + Ok(transport) + } + /// Creates a new instance of a [Transport] configured for use with /// [Elasticsearch service in Elastic Cloud](https://www.elastic.co/cloud/). /// @@ -413,23 +453,23 @@ impl Transport { Ok(transport) } - /// Creates an asynchronous request that can be awaited - pub async fn send( + #[allow(clippy::too_many_arguments)] + fn request_builder( &self, + connection: &Connection, method: Method, path: &str, headers: HeaderMap, query_string: Option<&Q>, body: Option, timeout: Option, - ) -> Result + ) -> Result where B: Body, Q: Serialize + ?Sized, { - let connection = self.conn_pool.next(); - let url = connection.url.join(path.trim_start_matches('/'))?; let reqwest_method = self.method(method); + let url = connection.url.join(path.trim_start_matches('/'))?; let mut request_builder = self.client.request(reqwest_method, url); if let Some(t) = timeout { @@ -493,6 +533,99 @@ impl Transport { if let Some(q) = query_string { request_builder = request_builder.query(q); } + Ok(request_builder) + } + + fn parse_to_url(address: &str, scheme: &str) -> Result { + if address.is_empty() { + return Err(crate::error::lib("Bound Address is empty")); + } + + let mut host_port = None; + if let Some((host, tail)) = address.split_once('/') { + if let Some((_, port)) = tail.rsplit_once(':') { + host_port = Some((host, port)); + } + } else { + host_port = address.rsplit_once(':'); + } + + let (host, port) = host_port.ok_or_else(|| { + crate::error::lib(format!("error parsing address into url: {}", address)) + })?; + + Ok(Url::parse( + format!("{}://{}:{}", scheme, host, port).as_str(), + )?) + } + + /// Creates an asynchronous request that can be awaited + pub async fn send( + &self, + method: Method, + path: &str, + headers: HeaderMap, + query_string: Option<&Q>, + body: Option, + timeout: Option, + ) -> Result + where + B: Body, + Q: Serialize + ?Sized, + { + // Requests will execute against old connection pool during reseed + if self.conn_pool.reseedable() { + let conn_pool = self.conn_pool.clone(); + let connection = conn_pool.next(); + + // Build node info request + let node_request = self.request_builder( + &connection, + Method::Get, + "_nodes/http?filter_path=nodes.*.http", + headers.clone(), + None::<&Q>, + None::, + timeout, + )?; + + tokio::spawn(async move { + let scheme = connection.url.scheme(); + let resp = node_request.send().await.unwrap(); + let json: Value = resp.json().await.unwrap(); + let connections: Vec = json["nodes"] + .as_object() + .unwrap() + .iter() + .map(|(_, node)| { + let address = node["http"]["publish_address"] + .as_str() + .or_else(|| { + Some( + node["http"]["bound_address"].as_array().unwrap()[0] + .as_str() + .unwrap(), + ) + }) + .unwrap(); + let url = Self::parse_to_url(address, scheme).unwrap(); + Connection::new(url) + }) + .collect(); + conn_pool.reseed(connections); + }); + } + + let connection = self.conn_pool.next(); + let request_builder = self.request_builder( + &connection, + method, + path, + headers, + query_string, + body, + timeout, + )?; let response = request_builder.send().await; match response { @@ -516,7 +649,14 @@ impl Default for Transport { /// dynamically at runtime, based upon the response to API calls. pub trait ConnectionPool: Debug + dyn_clone::DynClone + Sync + Send { /// Gets a reference to the next [Connection] - fn next(&self) -> &Connection; + fn next(&self) -> Connection; + + fn reseedable(&self) -> bool { + false + } + + // NOOP by default + fn reseed(&self, _connection: Vec) {} } clone_trait_object!(ConnectionPool); @@ -545,8 +685,8 @@ impl Default for SingleNodeConnectionPool { impl ConnectionPool for SingleNodeConnectionPool { /// Gets a reference to the next [Connection] - fn next(&self) -> &Connection { - &self.connection + fn next(&self) -> Connection { + self.connection.clone() } } @@ -667,8 +807,122 @@ impl CloudConnectionPool { impl ConnectionPool for CloudConnectionPool { /// Gets a reference to the next [Connection] - fn next(&self) -> &Connection { - &self.connection + fn next(&self) -> Connection { + self.connection.clone() + } +} + +/// A Connection Pool that manages a static connection of nodes +#[derive(Debug, Clone)] +pub struct MultiNodeConnectionPool { + inner: Arc>, + reseed_frequency: Option, + connection_selector: ConnSelector, + reseeding: Arc, +} + +#[derive(Debug, Clone)] +pub struct MultiNodeConnectionPoolInner { + last_update: Option, + connections: Vec, +} + +impl ConnectionPool for MultiNodeConnectionPool +where + ConnSelector: ConnectionSelector + Clone, +{ + fn next(&self) -> Connection { + let inner = self.inner.read().expect("lock poisoned"); + self.connection_selector + .try_next(&inner.connections) + .unwrap() + } + + fn reseedable(&self) -> bool { + let inner = self.inner.read().expect("lock poisoned"); + let reseed_frequency = match self.reseed_frequency { + Some(wait) => wait, + None => return false, + }; + let last_update_is_stale = inner + .last_update + .as_ref() + .map(|last_update| last_update.elapsed() > reseed_frequency); + let reseedable = last_update_is_stale.unwrap_or(true); + + if !reseedable { + false + } else { + // Check if refreshing is false if so, sets to true atomically and returns old value (false) meaning refreshable is true + // If refreshing is set to true, do nothing and return true, meaning refreshable is false + !self + .reseeding + .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed) + // This can be replaced with `.into_ok_or_err` once stable. + // https://doc.rust-lang.org/std/result/enum.Result.html#method.into_ok_or_err + .unwrap_or(true) + } + } + + fn reseed(&self, mut connection: Vec) { + let mut inner = self.inner.write().expect("lock poisoned"); + inner.last_update = Some(Instant::now()); + inner.connections.clear(); + inner.connections.append(&mut connection); + self.reseeding.store(false, Ordering::Relaxed); + } +} + +impl MultiNodeConnectionPool { + /** Use a round-robin strategy for balancing traffic over the given set of nodes. */ + pub fn round_robin(urls: Vec, reseed_frequency: Option) -> Self { + let connections = urls.into_iter().map(Connection::new).collect(); + + let inner: Arc> = + Arc::new(RwLock::new(MultiNodeConnectionPoolInner { + last_update: None, + connections, + })); + let reseeding = Arc::new(AtomicBool::new(false)); + + let connection_selector = RoundRobin::default(); + Self { + inner, + connection_selector, + reseed_frequency, + reseeding, + } + } +} + +/** The strategy selects an address from a given collection. */ +pub trait ConnectionSelector: Send + Sync + Debug { + /** Try get the next connection. */ + fn try_next(&self, connections: &[Connection]) -> Result; +} + +/** A round-robin strategy cycles through nodes sequentially. */ +#[derive(Clone, Debug)] +pub struct RoundRobin { + index: Arc, +} + +impl Default for RoundRobin { + fn default() -> Self { + RoundRobin { + index: Arc::new(AtomicUsize::new(0)), + } + } +} + +impl ConnectionSelector for RoundRobin { + fn try_next(&self, connections: &[Connection]) -> Result { + if connections.is_empty() { + Err(crate::error::lib("Connection list empty")) + } else { + let i = self.index.fetch_add(1, Ordering::Relaxed) % connections.len(); + Ok(connections[i].clone()) + } } } @@ -677,7 +931,13 @@ pub mod tests { use super::*; #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] use crate::auth::ClientCertificate; + use crate::http::transport::{ + CloudId, Connection, ConnectionPool, MultiNodeConnectionPool, SingleNodeConnectionPool, + Transport, TransportBuilder, + }; use regex::Regex; + use std::sync::atomic::Ordering; + use std::time::{Duration, Instant}; use url::Url; #[test] @@ -716,6 +976,24 @@ pub mod tests { ); } + #[test] + fn test_url_parsing_where_hostname_and_ip_present() { + let url = Transport::parse_to_url("localhost/127.0.0.1:9200", "http").unwrap(); + assert_eq!(url, Url::parse("http://localhost:9200/").unwrap()); + } + + #[test] + fn test_url_parsing_where_only_ip_present() { + let url = Transport::parse_to_url("127.0.0.1:9200", "http").unwrap(); + assert_eq!(url, Url::parse("http://127.0.0.1:9200/").unwrap()); + } + + #[test] + fn test_url_parsing_where_only_hostname_present() { + let url = Transport::parse_to_url("localhost:9200", "http").unwrap(); + assert_eq!(url, Url::parse("http://localhost:9200/").unwrap()); + } + #[test] fn can_parse_cloud_id_without_kibana_uuid() { let base64 = @@ -825,4 +1103,87 @@ pub mod tests { println!("{}", x); assert!(re.is_match(x)); } + + fn expected_addresses() -> Vec { + vec!["http://a:9200/", "http://b:9200/", "http://c:9200/"] + .iter() + .map(|addr| Url::parse(addr).unwrap()) + .collect() + } + + #[test] + fn test_reseedable_false_on_no_duration() { + let connections = MultiNodeConnectionPool::round_robin(expected_addresses(), None); + assert!(!connections.reseedable()); + } + + #[test] + fn test_reseed() { + let connection_pool = + MultiNodeConnectionPool::round_robin(vec![], Some(Duration::from_secs(28800))); + + let connections = expected_addresses() + .into_iter() + .map(Connection::new) + .collect(); + connection_pool.reseed(connections); + for _ in 0..10 { + for expected in expected_addresses() { + let actual = connection_pool.next(); + + assert_eq!(expected.as_str(), actual.url.as_str()); + } + } + // Check connection pool not reseedable after reseed + assert!(!connection_pool.reseedable()); + assert!(!connection_pool.reseeding.load(Ordering::Relaxed)); + } + + #[test] + fn test_reseedable_after_duration() { + let connection_pool = MultiNodeConnectionPool::round_robin( + expected_addresses(), + Some(Duration::from_secs(30)), + ); + + // Set internal last_update to a minute ago + let mut inner = connection_pool.inner.write().expect("lock poisoned"); + inner.last_update = Some(Instant::now() - Duration::from_secs(60)); + drop(inner); + + assert!(connection_pool.reseedable()); + assert!(connection_pool.reseeding.load(Ordering::Relaxed)); + } + + #[test] + fn round_robin_next_multi() { + let connections = MultiNodeConnectionPool::round_robin(expected_addresses(), None); + + for _ in 0..10 { + for expected in expected_addresses() { + let actual = connections.next(); + + assert_eq!(expected.as_str(), actual.url.as_str()); + } + } + } + + #[test] + fn round_robin_next_single() { + let expected = Url::parse("http://a:9200/").unwrap(); + let connections = MultiNodeConnectionPool::round_robin(vec![expected.clone()], None); + + for _ in 0..10 { + let actual = connections.next(); + + assert_eq!(expected.as_str(), actual.url.as_str()); + } + } + + #[test] + #[should_panic] + fn round_robin_next_empty_fails() { + let connections = MultiNodeConnectionPool::round_robin(vec![], None); + connections.next(); + } }