From ea1575ac92c59c43736ab7c4714e112f77630274 Mon Sep 17 00:00:00 2001 From: Carter Green Date: Fri, 31 May 2024 07:35:23 -0500 Subject: [PATCH] ADD: Add configurable heartbeat support to clients --- CHANGELOG.md | 20 ++++++- src/live.rs | 78 +++++++++++++++++++++++--- src/live/client.rs | 136 ++++++++++++++++++++++++++++++--------------- 3 files changed, 178 insertions(+), 56 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 151efb6..277bd3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,19 @@ # Changelog +## 0.11.0 - TBD + +#### Enhancements +- Added configurable `heartbeat_interval` parameter for live client that determines the + timeout before heartbeat `SystemMsg` records will be sent. It can be configured via + the `heartbeat_interval` and `heartbeat_interval_s` methods of the + `live::ClientBuilder` +- Added `addr` function to `live::ClientBuilder` for configuring a custom gateway + address without using `LiveClient::connect_with_addr` directly + +#### Breaking changes +- Added `heartbeat_interval` parameter to `LiveClient::connect` and + `LiveClient::connect_with_addr` + ## 0.10.0 - 2024-05-22 #### Enhancements @@ -56,10 +70,10 @@ - Document `live::Subscription::start` is based on `ts_event` - Allow constructing a `DateRange` and `DateTimeRange` with an `end` based on a `time::Duration` -- Implemented `Debug` for `LiveClient`, `LiveClientBuilder`, `HistoricalClient`, - `HistoricalClientBuilder`, `BatchClient`, `MetadataClient`, `SymbologyClient`, and +- Implemented `Debug` for `LiveClient`, `live::ClientBuilder`, `HistoricalClient`, + `historical::ClientBuilder`, `BatchClient`, `MetadataClient`, `SymbologyClient`, and `TimeseriesClient` -- Derived `Clone` for `LiveClientBuilder` and `HistoricalClientBuilder` +- Derived `Clone` for `live::ClientBuilder` and `historical::ClientBuilder` - Added `ApiKey` type for safely deriving `Debug` for types containing an API key #### Breaking changes diff --git a/src/live.rs b/src/live.rs index 599460e..7bb493d 100644 --- a/src/live.rs +++ b/src/live.rs @@ -2,8 +2,12 @@ mod client; +use std::{net::SocketAddr, sync::Arc}; + use dbn::{SType, Schema, VersionUpgradePolicy}; -use time::OffsetDateTime; +use log::warn; +use time::{Duration, OffsetDateTime}; +use tokio::net::{lookup_host, ToSocketAddrs}; use typed_builder::TypedBuilder; use crate::{ApiKey, Symbols}; @@ -43,19 +47,23 @@ pub struct Unset; /// - `dataset` #[derive(Debug, Clone)] pub struct ClientBuilder { + addr: Option>, key: AK, dataset: D, send_ts_out: bool, upgrade_policy: VersionUpgradePolicy, + heartbeat_interval: Option, } impl Default for ClientBuilder { fn default() -> Self { Self { + addr: None, key: Unset, dataset: Unset, send_ts_out: false, upgrade_policy: VersionUpgradePolicy::Upgrade, + heartbeat_interval: None, } } } @@ -74,6 +82,43 @@ impl ClientBuilder { self.upgrade_policy = upgrade_policy; self } + + /// Sets `heartbeat_interval`, which controls the interval at which the gateway + /// will send heartbeat records if no other data records are sent. If no heartbeat + /// interval is configured, the gateway default will be used. + /// + /// Note that granularity of less than a second is not supported and will be + /// ignored. + pub fn heartbeat_interval(mut self, heartbeat_interval: Duration) -> Self { + if heartbeat_interval.subsec_nanoseconds() > 0 { + warn!( + "heartbeat_interval subsecond precision ignored: {}ns", + heartbeat_interval.subsec_nanoseconds() + ) + } + self.heartbeat_interval = Some(heartbeat_interval); + self + } + + /// Overrides the address of the gateway the client will connect to. This is an + /// advanced method. + /// + /// # Errors + /// This function returns an error when `addr` fails to resolve. + pub async fn addr(mut self, addr: impl ToSocketAddrs) -> crate::Result { + const PARAM_NAME: &str = "addr"; + self.addr = Some( + lookup_host(addr) + .await + .map_err(|e| crate::Error::bad_arg(PARAM_NAME, format!("{e}")))? + .next() + .map(Arc::new) + .ok_or_else(|| { + crate::Error::bad_arg(PARAM_NAME, "did not resolve to any host".to_owned()) + })?, + ); + Ok(self) + } } impl ClientBuilder { @@ -90,10 +135,12 @@ impl ClientBuilder { /// This function returns an error when the API key is invalid. pub fn key(self, key: impl ToString) -> crate::Result> { Ok(ClientBuilder { + addr: self.addr, key: crate::validate_key(key.to_string())?, dataset: self.dataset, send_ts_out: self.send_ts_out, upgrade_policy: self.upgrade_policy, + heartbeat_interval: self.heartbeat_interval, }) } @@ -113,10 +160,12 @@ impl ClientBuilder { /// Sets the dataset. pub fn dataset(self, dataset: impl ToString) -> ClientBuilder { ClientBuilder { + addr: self.addr, key: self.key, dataset: dataset.to_string(), send_ts_out: self.send_ts_out, upgrade_policy: self.upgrade_policy, + heartbeat_interval: self.heartbeat_interval, } } } @@ -128,12 +177,25 @@ impl ClientBuilder { /// This function returns an error when its unable /// to connect and authenticate with the Live gateway. pub async fn build(self) -> crate::Result { - Client::connect( - self.key.0, - self.dataset, - self.send_ts_out, - self.upgrade_policy, - ) - .await + if let Some(addr) = self.addr { + Client::connect_with_addr( + *addr, + self.key.0, + self.dataset, + self.send_ts_out, + self.upgrade_policy, + self.heartbeat_interval, + ) + .await + } else { + Client::connect( + self.key.0, + self.dataset, + self.send_ts_out, + self.upgrade_policy, + self.heartbeat_interval, + ) + .await + } } } diff --git a/src/live/client.rs b/src/live/client.rs index 3a2ad17..855f7f0 100644 --- a/src/live/client.rs +++ b/src/live/client.rs @@ -7,6 +7,7 @@ use dbn::{ use hex::ToHex; use log::{debug, error, info}; use sha2::{Digest, Sha256}; +use time::Duration; use tokio::{ io::{AsyncBufReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}, net::{TcpStream, ToSocketAddrs}, @@ -40,13 +41,16 @@ impl Client { /// Creates a new client connected to a Live gateway. /// /// # Errors - /// This function returns an error when `key` is invalid or its unable to connect + /// This function returns an error when `key` is invalid or it's unable to connect /// and authenticate with the Live gateway. + /// This function returns an error when `key` or `heartbeat_interval` are invalid, + /// or it's unable to connect and authenticate with the Live gateway. pub async fn connect( key: String, dataset: String, send_ts_out: bool, upgrade_policy: VersionUpgradePolicy, + heartbeat_interval: Option, ) -> crate::Result { Self::connect_with_addr( Self::determine_gateway(&dataset), @@ -54,6 +58,7 @@ impl Client { dataset, send_ts_out, upgrade_policy, + heartbeat_interval, ) .await } @@ -62,14 +67,15 @@ impl Client { /// [`builder()`](Self::builder) or [`connect()`](Self::connect) should be used instead. /// /// # Errors - /// This function returns an error when `key` is invalid or its unable - /// to connect and authenticate with the Live gateway. + /// This function returns an error when `key` or `heartbeat_interval` are invalid, + /// or it's unable to connect and authenticate with the Live gateway. pub async fn connect_with_addr( addr: impl ToSocketAddrs, key: String, dataset: String, send_ts_out: bool, upgrade_policy: VersionUpgradePolicy, + heartbeat_interval: Option, ) -> crate::Result { let key = validate_key(key)?; let stream = TcpStream::connect(addr).await?; @@ -77,8 +83,15 @@ impl Client { let mut reader = BufReader::new(reader); // Authenticate CRAM - let session_id = - Self::cram_challenge(&mut reader, &mut writer, &key.0, &dataset, send_ts_out).await?; + let session_id = Self::cram_challenge( + &mut reader, + &mut writer, + &key.0, + &dataset, + send_ts_out, + heartbeat_interval, + ) + .await?; Ok(Self { key, @@ -147,6 +160,7 @@ impl Client { key: &str, dataset: &str, send_ts_out: bool, + heartbeat_interval: Option, ) -> crate::Result { let mut greeting = String::new(); // Greeting @@ -179,8 +193,13 @@ impl Client { let bucket_id = &key[API_KEY_LENGTH - BUCKET_ID_LENGTH..]; let encoded_response = result.encode_hex::(); let send_ts_out = send_ts_out as i32; - let reply = - format!("auth={encoded_response}-{bucket_id}|dataset={dataset}|encoding=dbn|ts_out={send_ts_out}|client=Rust {}\n", env!("CARGO_PKG_VERSION")); + let mut reply = + format!("auth={encoded_response}-{bucket_id}|dataset={dataset}|encoding=dbn|ts_out={send_ts_out}|client=Rust {}", env!("CARGO_PKG_VERSION")); + if let Some(heartbeat_interval_s) = heartbeat_interval.map(|i| i.whole_seconds()) { + reply = format!("{reply}|heartbeat_interval_s={heartbeat_interval_s}\n") + } else { + reply.push('\n'); + } // Send CRAM reply debug!( @@ -339,7 +358,7 @@ impl fmt::Debug for Client { #[cfg(test)] #[allow(deprecated)] mod tests { - use std::{ffi::c_char, fmt, time::Duration}; + use std::{ffi::c_char, fmt}; use dbn::{ encode::AsyncDbnMetadataEncoder, @@ -348,6 +367,7 @@ mod tests { record::{HasRType, OhlcvMsg, RecordHeader, TradeMsg, WithTsOut}, FlagSet, Mbp10Msg, MetadataBuilder, Record, SType, Schema, }; + use time::Duration; use tokio::{ io::BufReader, join, @@ -383,7 +403,7 @@ mod tests { self.stream = Some(BufReader::new(stream)); } - async fn authenticate(&mut self) { + async fn authenticate(&mut self, heartbeat_interval: Option) { self.accept().await; self.send("lsg-test\n").await; self.send("cram=t7kNhwj4xqR0QYjzFKtBEG2ec2pXJ4FK\n").await; @@ -401,6 +421,14 @@ mod tests { assert!(auth_line.contains("encoding=dbn")); assert!(auth_line.contains(&format!("ts_out={}", if self.send_ts_out { 1 } else { 0 }))); assert!(auth_line.contains(&format!("client=Rust {}", env!("CARGO_PKG_VERSION")))); + if let Some(heartbeat_interval) = heartbeat_interval { + assert!(auth_line.contains(&format!( + "heartbeat_interval_s={}", + heartbeat_interval.whole_seconds() + ))); + } else { + assert!(!auth_line.contains("heartbeat_interval_s=")); + } self.send("success=1|session_id=5\n").await; } @@ -475,7 +503,7 @@ mod tests { enum Event { Stop, Accept, - Authenticate, + Authenticate(Option), Send(String), Subscribe(Subscription), Start, @@ -487,7 +515,7 @@ mod tests { match self { Event::Stop => write!(f, "Stop"), Event::Accept => write!(f, "Accept"), - Event::Authenticate => write!(f, "Authenticate"), + Event::Authenticate(hb_int) => write!(f, "Authenticate({hb_int:?})"), Event::Send(msg) => write!(f, "Send({msg:?})"), Event::Subscribe(sub) => write!(f, "Subscribe({sub:?})"), Event::Start => write!(f, "Start"), @@ -504,7 +532,7 @@ mod tests { let task = tokio::task::spawn(async move { loop { match recv.recv().await { - Some(Event::Authenticate) => mock.authenticate().await, + Some(Event::Authenticate(hb_int)) => mock.authenticate(hb_int).await, Some(Event::Accept) => mock.accept().await, Some(Event::Send(msg)) => mock.send(&msg).await, Some(Event::Subscribe(sub)) => mock.subscribe(sub).await, @@ -523,8 +551,10 @@ mod tests { } /// Accept and authenticate - pub fn authenticate(&mut self) { - self.send.send(Event::Authenticate).unwrap(); + pub fn authenticate(&mut self, heartbeat_interval: Option) { + self.send + .send(Event::Authenticate(heartbeat_interval)) + .unwrap(); } pub fn expect_subscribe(&mut self, subscription: Subscription) { @@ -554,17 +584,28 @@ mod tests { } } - async fn setup(dataset: Dataset, send_ts_out: bool) -> (Fixture, Client) { + async fn setup( + dataset: Dataset, + send_ts_out: bool, + heartbeat_interval: Option, + ) -> (Fixture, Client) { let _ = env_logger::try_init(); let mut fixture = Fixture::new(dataset.to_string(), send_ts_out).await; - fixture.authenticate(); - let target = Client::connect_with_addr( - format!("127.0.0.1:{}", fixture.port), - "32-character-with-lots-of-filler".to_owned(), - dataset.to_string(), - send_ts_out, - VersionUpgradePolicy::AsIs, - ) + fixture.authenticate(heartbeat_interval); + let builder = Client::builder() + .addr(format!("127.0.0.1:{}", fixture.port)) + .await + .unwrap() + .key("32-character-with-lots-of-filler".to_owned()) + .unwrap() + .dataset(dataset.to_string()) + .send_ts_out(send_ts_out); + let target = if let Some(heartbeat_interval) = heartbeat_interval { + builder.heartbeat_interval(heartbeat_interval) + } else { + builder + } + .build() .await .unwrap(); (fixture, target) @@ -572,7 +613,7 @@ mod tests { #[tokio::test] async fn test_subscribe() { - let (mut fixture, mut client) = setup(Dataset::XnasItch, false).await; + let (mut fixture, mut client) = setup(Dataset::XnasItch, false, None).await; let subscription = Subscription::builder() .symbols(vec!["MSFT", "TSLA", "QQQ"]) .schema(Schema::Ohlcv1M) @@ -585,7 +626,8 @@ mod tests { #[tokio::test] async fn test_subscribe_snapshot() { - let (mut fixture, mut client) = setup(Dataset::XnasItch, false).await; + let (mut fixture, mut client) = + setup(Dataset::XnasItch, false, Some(Duration::MINUTE)).await; let subscription = Subscription::builder() .symbols(vec!["MSFT", "TSLA", "QQQ"]) .schema(Schema::Ohlcv1M) @@ -599,7 +641,8 @@ mod tests { #[tokio::test] async fn test_subscribe_snapshot_failed() { - let (fixture, mut client) = setup(Dataset::XnasItch, false).await; + let (fixture, mut client) = + setup(Dataset::XnasItch, false, Some(Duration::seconds(5))).await; let err = client .subscribe( @@ -624,7 +667,7 @@ mod tests { async fn test_subscription_chunking() { const SYMBOL: &str = "TEST"; const SYMBOL_COUNT: usize = 1000; - let (mut fixture, mut client) = setup(Dataset::XnasItch, false).await; + let (mut fixture, mut client) = setup(Dataset::XnasItch, false, None).await; let sub_base = Subscription::builder() .schema(Schema::Ohlcv1M) .stype_in(SType::RawSymbol); @@ -649,7 +692,8 @@ mod tests { close: 4, volume: 5, }; - let (mut fixture, mut client) = setup(Dataset::GlbxMdp3, false).await; + let (mut fixture, mut client) = + setup(Dataset::GlbxMdp3, false, Some(Duration::minutes(5))).await; fixture.start(); let metadata = client.start().await.unwrap(); assert_eq!(metadata.version, dbn::DBN_VERSION); @@ -678,7 +722,7 @@ mod tests { }, time::OffsetDateTime::now_utc().unix_timestamp_nanos() as u64, ); - let (mut fixture, mut client) = setup(Dataset::GlbxMdp3, true).await; + let (mut fixture, mut client) = setup(Dataset::GlbxMdp3, true, None).await; fixture.start(); let metadata = client.start().await.unwrap(); assert_eq!(metadata.version, dbn::DBN_VERSION); @@ -692,7 +736,8 @@ mod tests { #[tokio::test] async fn test_close() { - let (mut fixture, mut client) = setup(Dataset::GlbxMdp3, true).await; + let (mut fixture, mut client) = + setup(Dataset::GlbxMdp3, true, Some(Duration::seconds(45))).await; fixture.start(); client.start().await.unwrap(); client.close().await.unwrap(); @@ -704,14 +749,15 @@ mod tests { const DATASET: Dataset = Dataset::OpraPillar; let mut fixture = Fixture::new(DATASET.to_string(), false).await; let client_task = tokio::spawn(async move { - let res = Client::connect_with_addr( - format!("127.0.0.1:{}", fixture.port), - "32-character-with-lots-of-filler".to_owned(), - DATASET.to_string(), - false, - VersionUpgradePolicy::AsIs, - ) - .await; + let res = Client::builder() + .addr(format!("127.0.0.1:{}", fixture.port)) + .await + .unwrap() + .key("32-character-with-lots-of-filler".to_owned()) + .unwrap() + .dataset(DATASET.to_string()) + .build() + .await; if let Err(e) = &res { dbg!(e); } @@ -731,7 +777,7 @@ mod tests { #[tokio::test] async fn test_cancellation_safety() { - let (mut fixture, mut client) = setup(Dataset::GlbxMdp3, true).await; + let (mut fixture, mut client) = setup(Dataset::GlbxMdp3, true, None).await; fixture.start(); let metadata = client.start().await.unwrap(); assert_eq!(metadata.version, dbn::DBN_VERSION); @@ -739,12 +785,12 @@ mod tests { assert_eq!(metadata.dataset, Dataset::GlbxMdp3.as_str()); fixture.send_record(Mbp10Msg::default()); - let mut int_1 = tokio::time::interval(Duration::from_millis(1)); - let mut int_2 = tokio::time::interval(Duration::from_millis(1)); - let mut int_3 = tokio::time::interval(Duration::from_millis(1)); - let mut int_4 = tokio::time::interval(Duration::from_millis(1)); - let mut int_5 = tokio::time::interval(Duration::from_millis(1)); - let mut int_6 = tokio::time::interval(Duration::from_millis(1)); + let mut int_1 = tokio::time::interval(std::time::Duration::from_millis(1)); + let mut int_2 = tokio::time::interval(std::time::Duration::from_millis(1)); + let mut int_3 = tokio::time::interval(std::time::Duration::from_millis(1)); + let mut int_4 = tokio::time::interval(std::time::Duration::from_millis(1)); + let mut int_5 = tokio::time::interval(std::time::Duration::from_millis(1)); + let mut int_6 = tokio::time::interval(std::time::Duration::from_millis(1)); for _ in 0..5_000 { select! { _ = int_1.tick() => {