diff --git a/xline/src/server/mod.rs b/xline/src/server/mod.rs index f50459b468..da94d0302a 100644 --- a/xline/src/server/mod.rs +++ b/xline/src/server/mod.rs @@ -17,6 +17,8 @@ mod watch_server; /// Xline server mod xline_server; -pub use self::command::{Command, KeyRange}; pub(crate) use self::maintenance::MAINTENANCE_SNAPSHOT_CHUNK_SIZE; -pub use self::xline_server::XlineServer; +pub use self::{ + command::{Command, KeyRange}, + xline_server::XlineServer, +}; diff --git a/xline/src/server/watch_server.rs b/xline/src/server/watch_server.rs index 8669962517..c625cbb258 100644 --- a/xline/src/server/watch_server.rs +++ b/xline/src/server/watch_server.rs @@ -83,6 +83,8 @@ where header_gen, ); let mut ticker = tokio::time::interval(watch_progress_notify_interval); + let stop_listener = stop_notify.listen(); + tokio::pin!(stop_listener); loop { tokio::select! { req = req_rx.next() => { @@ -111,7 +113,9 @@ where _ = ticker.tick() => { watch_handle.handle_tick_progress().await; } - _ = stop_notify.listen() => { + // To ensure that each iteration invokes the same `stop_listener` and keeps + // events losing due to the cancellation of `stop_listener` at bay. + _ = &mut stop_listener => { break; } } @@ -666,4 +670,49 @@ mod test { handle.await.unwrap(); Ok(()) } + + #[tokio::test] + async fn test_watch_task_should_be_terminated_when_response_tx_is_closed( + ) -> Result<(), Box> { + let (req_tx, req_rx) = mpsc::channel(CHANNEL_SIZE); + let (res_tx, res_rx) = mpsc::channel(CHANNEL_SIZE); + let req_stream: ReceiverStream> = + ReceiverStream::new(req_rx); + let header_gen = Arc::new(HeaderGenerator::new(0, 0)); + let mut mock_watcher = MockKvWatcherOps::new(); + let _ = mock_watcher.expect_watch().times(1).return_const(()); + let _ = mock_watcher.expect_cancel().times(1).return_const(()); + let watcher = Arc::new(mock_watcher); + let next_id = Arc::new(WatchIdGenerator::new(1)); + let handle = tokio::spawn(WatchServer::::task( + next_id, + Arc::clone(&watcher), + res_tx, + req_stream, + header_gen, + Duration::from_millis(100), + )); + + req_tx + .send(Ok(WatchRequest { + request_union: Some(RequestUnion::CreateRequest(WatchCreateRequest { + key: "foo".into(), + progress_notify: true, + watch_id: 1, + ..Default::default() + })), + })) + .await?; + + drop(res_rx); + + req_tx + .send(Ok(WatchRequest { + request_union: Some(RequestUnion::ProgressRequest(WatchProgressRequest {})), + })) + .await?; + + assert!(timeout(Duration::from_secs(10), handle).await.is_ok()); + Ok(()) + } }