diff --git a/async-nats/src/jetstream/consumer/pull.rs b/async-nats/src/jetstream/consumer/pull.rs index e28e167ad..1c7938804 100644 --- a/async-nats/src/jetstream/consumer/pull.rs +++ b/async-nats/src/jetstream/consumer/pull.rs @@ -12,21 +12,15 @@ // limitations under the License. use bytes::Bytes; -use futures::{future::BoxFuture, FutureExt, StreamExt, TryFutureExt}; +use futures::{ + future::{BoxFuture, Pending}, + FutureExt, StreamExt, TryFutureExt, +}; #[cfg(feature = "server_2_10")] use std::collections::HashMap; -use std::{ - future, - pin::Pin, - sync::{Arc, Mutex}, - task::Poll, - time::Duration, -}; -use tokio::{ - task::JoinHandle, - time::{Instant, Sleep}, -}; +use std::{future, pin::Pin, task::Poll, time::Duration}; +use tokio::{task::JoinHandle, time::Sleep}; use serde::{Deserialize, Serialize}; use tracing::{debug, trace}; @@ -839,18 +833,13 @@ pub struct Stream { context: Context, pending_request: bool, task_handle: JoinHandle<()>, - heartbeat_handle: Option>, - last_seen: Arc>, - heartbeats_missing: tokio::sync::mpsc::Receiver<()>, terminated: bool, + heartbeat_timeout: Option>>>>, } impl Drop for Stream { fn drop(&mut self) { self.task_handle.abort(); - if let Some(handle) = self.heartbeat_handle.take() { - handle.abort() - } } } @@ -942,47 +931,11 @@ impl Stream { .unwrap(); trace!("result send over tx"); } - // } } }); - let last_seen = Arc::new(Mutex::new(Instant::now())); - let (missed_heartbeat_tx, missed_heartbeat_rx) = tokio::sync::mpsc::channel(1); - let heartbeat_handle = if !batch_config.idle_heartbeat.is_zero() { - debug!("spawning heartbeat checker task"); - Some(tokio::task::spawn({ - let last_seen = last_seen.clone(); - async move { - loop { - tokio::time::sleep(batch_config.idle_heartbeat).await; - debug!("checking for missed heartbeats"); - let should_reset = { - let mut last_seen = last_seen.lock().unwrap(); - if last_seen - .elapsed() - .gt(&batch_config.idle_heartbeat.saturating_mul(2)) - { - // If we met the missed heartbeat threshold, reset the timer - // so it will not be instantly triggered again. - *last_seen = Instant::now(); - true - } else { - false - } - }; - if should_reset { - debug!("missed heartbeat threshold met"); - missed_heartbeat_tx.send(()).await.unwrap(); - } - } - } - })) - } else { - None - }; Ok(Stream { task_handle, - heartbeat_handle, request_result_rx, request_tx, batch_config, @@ -991,9 +944,8 @@ impl Stream { subscriber: subscription, context: consumer.context.clone(), pending_request: false, - last_seen, - heartbeats_missing: missed_heartbeat_rx, terminated: false, + heartbeat_timeout: None, }) } } @@ -1095,6 +1047,28 @@ impl futures::Stream for Stream { if self.terminated { return Poll::Ready(None); } + + if !self.batch_config.idle_heartbeat.is_zero() { + trace!("setting hearbeats"); + let timeout = self.batch_config.idle_heartbeat.saturating_mul(2); + self.heartbeat_timeout.get_or_insert_with(|| { + Box::pin(tokio::time::timeout(timeout, futures::future::pending())) + }); + + trace!("checking idle hearbeats"); + if let Some(hearbeat) = self.heartbeat_timeout.as_mut() { + match hearbeat.try_poll_unpin(cx) { + Poll::Ready(_) => { + self.heartbeat_timeout = None; + return Poll::Ready(Some(Err(MessagesError::new( + MessagesErrorKind::MissingHeartbeat, + )))); + } + Poll::Pending => (), + } + } + } + loop { trace!("pending messages: {}", self.pending_messages); if (self.pending_messages <= self.batch_config.batch / 2 @@ -1106,28 +1080,7 @@ impl futures::Stream for Stream { self.request_tx.send(()).unwrap(); self.pending_request = true; } - if self.heartbeat_handle.is_some() { - match self.heartbeats_missing.poll_recv(cx) { - Poll::Ready(resp) => match resp { - Some(()) => { - trace!("received missing heartbeats notification"); - return Poll::Ready(Some(Err(MessagesError::new( - MessagesErrorKind::MissingHeartbeat, - )))); - } - None => { - self.terminated = true; - return Poll::Ready(Some(Err(MessagesError::with_source( - MessagesErrorKind::Other, - "unexpected termination of heartbeat checker", - )))); - } - }, - Poll::Pending => { - trace!("pending message from missing heartbeats notification channel"); - } - } - } + match self.request_result_rx.poll_recv(cx) { Poll::Ready(resp) => match resp { Some(resp) => match resp { @@ -1157,102 +1110,96 @@ impl futures::Stream for Stream { trace!("pending result"); } } + trace!("polling subscriber"); match self.subscriber.receiver.poll_recv(cx) { - Poll::Ready(maybe_message) => match maybe_message { - Some(message) => match message.status.unwrap_or(StatusCode::OK) { - StatusCode::TIMEOUT | StatusCode::REQUEST_TERMINATED => { - debug!("received status message: {:?}", message); - // If consumer has been deleted, error and shutdown the iterator. - if message.description.as_deref() == Some("Consumer Deleted") { - self.terminated = true; - return Poll::Ready(Some(Err(MessagesError::new( - MessagesErrorKind::ConsumerDeleted, - )))); + Poll::Ready(maybe_message) => { + self.heartbeat_timeout = None; + match maybe_message { + Some(message) => match message.status.unwrap_or(StatusCode::OK) { + StatusCode::TIMEOUT | StatusCode::REQUEST_TERMINATED => { + debug!("received status message: {:?}", message); + // If consumer has been deleted, error and shutdown the iterator. + if message.description.as_deref() == Some("Consumer Deleted") { + self.terminated = true; + return Poll::Ready(Some(Err(MessagesError::new( + MessagesErrorKind::ConsumerDeleted, + )))); + } + // If consumer is not pull based, error and shutdown the iterator. + if message.description.as_deref() == Some("Consumer is push based") + { + self.terminated = true; + return Poll::Ready(Some(Err(MessagesError::new( + MessagesErrorKind::PushBasedConsumer, + )))); + } + // All other cases can be handled. + + // Do accounting for messages left after terminated/completed pull request. + let pending_messages = message + .headers + .as_ref() + .and_then(|headers| headers.get("Nats-Pending-Messages")) + .map(|h| h.iter()) + .and_then(|mut i| i.next()) + .map(|e| e.parse::()) + .unwrap_or(Ok(self.batch_config.batch)) + .map_err(|err| { + MessagesError::with_source(MessagesErrorKind::Other, err) + })?; + let pending_bytes = message + .headers + .as_ref() + .and_then(|headers| headers.get("Nats-Pending-Bytes")) + .map(|h| h.iter()) + .and_then(|mut i| i.next()) + .map(|e| e.parse::()) + .unwrap_or(Ok(self.batch_config.max_bytes)) + .map_err(|err| { + MessagesError::with_source(MessagesErrorKind::Other, err) + })?; + debug!( + "timeout reached. remaining messages: {}, bytes {}", + pending_messages, pending_bytes + ); + self.pending_messages = + self.pending_messages.saturating_sub(pending_messages); + trace!("message bytes len: {}", pending_bytes); + self.pending_bytes = + self.pending_bytes.saturating_sub(pending_bytes); + continue; } - // If consumer is not pull based, error and shutdown the iterator. - if message.description.as_deref() == Some("Consumer is push based") { - self.terminated = true; - return Poll::Ready(Some(Err(MessagesError::new( - MessagesErrorKind::PushBasedConsumer, - )))); + // Idle Hearbeat means we have no messages, but consumer is fine. + StatusCode::IDLE_HEARTBEAT => { + debug!("received idle heartbeat"); + continue; } - // All other cases can be handled. - - // Got a status message from a consumer, meaning it's alive. - // Update last seen. - if !self.batch_config.idle_heartbeat.is_zero() { - *self.last_seen.lock().unwrap() = Instant::now(); - } - - // Do accounting for messages left after terminated/completed pull request. - let pending_messages = message - .headers - .as_ref() - .and_then(|headers| headers.get("Nats-Pending-Messages")) - .map(|h| h.iter()) - .and_then(|mut i| i.next()) - .map(|e| e.parse::()) - .unwrap_or(Ok(self.batch_config.batch)) - .map_err(|err| { - MessagesError::with_source(MessagesErrorKind::Other, err) - })?; - let pending_bytes = message - .headers - .as_ref() - .and_then(|headers| headers.get("Nats-Pending-Bytes")) - .map(|h| h.iter()) - .and_then(|mut i| i.next()) - .map(|e| e.parse::()) - .unwrap_or(Ok(self.batch_config.max_bytes)) - .map_err(|err| { - MessagesError::with_source(MessagesErrorKind::Other, err) - })?; - debug!( - "timeout reached. remaining messages: {}, bytes {}", - pending_messages, pending_bytes - ); - self.pending_messages = - self.pending_messages.saturating_sub(pending_messages); - trace!("message bytes len: {}", pending_bytes); - self.pending_bytes = self.pending_bytes.saturating_sub(pending_bytes); - continue; - } - // Idle Hearbeat means we have no messages, but consumer is fine. - StatusCode::IDLE_HEARTBEAT => { - debug!("received idle heartbeat"); - if !self.batch_config.idle_heartbeat.is_zero() { - *self.last_seen.lock().unwrap() = Instant::now(); + // We got an message from a stream. + StatusCode::OK => { + trace!("message received"); + self.pending_messages = self.pending_messages.saturating_sub(1); + self.pending_bytes = + self.pending_bytes.saturating_sub(message.length); + return Poll::Ready(Some(Ok(jetstream::Message { + context: self.context.clone(), + message, + }))); } - continue; - } - // We got an message from a stream. - StatusCode::OK => { - trace!("message received"); - if !self.batch_config.idle_heartbeat.is_zero() { - *self.last_seen.lock().unwrap() = Instant::now(); + status => { + debug!("received unknown message: {:?}", message); + return Poll::Ready(Some(Err(MessagesError::with_source( + MessagesErrorKind::Other, + format!( + "error while processing messages from the stream: {}, {:?}", + status, message.description + ), + )))); } - *self.last_seen.lock().unwrap() = Instant::now(); - self.pending_messages = self.pending_messages.saturating_sub(1); - self.pending_bytes = self.pending_bytes.saturating_sub(message.length); - return Poll::Ready(Some(Ok(jetstream::Message { - context: self.context.clone(), - message, - }))); - } - status => { - debug!("received unknown message: {:?}", message); - return Poll::Ready(Some(Err(MessagesError::with_source( - MessagesErrorKind::Other, - format!( - "error while processing messages from the stream: {}, {:?}", - status, message.description - ), - )))); - } - }, - None => return Poll::Ready(None), - }, + }, + None => return Poll::Ready(None), + } + } Poll::Pending => { debug!("subscriber still pending"); return std::task::Poll::Pending; diff --git a/async-nats/tests/jetstream_tests.rs b/async-nats/tests/jetstream_tests.rs index c5fa85ebd..369fe5d11 100644 --- a/async-nats/tests/jetstream_tests.rs +++ b/async-nats/tests/jetstream_tests.rs @@ -2109,7 +2109,11 @@ mod jetstream { #[cfg(feature = "slow_tests")] #[tokio::test] async fn pull_consumer_stream_with_heartbeat() { - use tracing::debug; + tracing_subscriber::fmt() + .with_max_level(Level::DEBUG) + .init(); + + use tracing::{debug, Level}; let server = nats_server::run_server("tests/configs/jetstream.conf"); let client = ConnectOptions::new() .event_callback(|err| async move { println!("error: {err:?}") }) @@ -2175,7 +2179,10 @@ mod jetstream { .unwrap(); // and expect the message to be there. debug!("awaiting the message with recreated consumer"); - messages.next().await.unwrap().unwrap(); + let now = Instant::now(); + let m = messages.next().await.unwrap(); + println!("after: {:?}", now.elapsed()); + m.unwrap(); } #[tokio::test] @@ -3219,7 +3226,12 @@ mod jetstream { .unwrap(); for _ in 0..10 { - context.publish("test".into(), "data".into()).await.unwrap(); + context + .publish("test".into(), "data".into()) + .await + .unwrap() + .await + .unwrap(); } let consumer = stream