Skip to content

Commit

Permalink
fix: avoid losing events when using event_lisenter and tokio::select!…
Browse files Browse the repository at this point in the history
… together

Refs: xline-kv#338, xline-kv#339
Closes: xline-kv#338, xline-kv#339
  • Loading branch information
Phoenix500526 committed Jun 22, 2023
1 parent 90ec64f commit 06c9cc8
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
6 changes: 4 additions & 2 deletions xline/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ mod watch_server;
/// Xline server
mod xline_server;

pub use self::command::{Command, KeyRange};
pub(crate) use self::maintenance::MAINTENANCE_SNAPSHOT_CHUNK_SIZE;
pub use self::xline_server::XlineServer;
pub use self::{
command::{Command, KeyRange},
xline_server::XlineServer,
};
51 changes: 50 additions & 1 deletion xline/src/server/watch_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ where
header_gen,
);
let mut ticker = tokio::time::interval(watch_progress_notify_interval);
let stop_listener = stop_notify.listen();
tokio::pin!(stop_listener);
loop {
tokio::select! {
req = req_rx.next() => {
Expand Down Expand Up @@ -111,7 +113,9 @@ where
_ = ticker.tick() => {
watch_handle.handle_tick_progress().await;
}
_ = stop_notify.listen() => {
// To ensure that each iteration invokes the same `stop_listener` and keeps
// events losing due to the cancellation of `stop_listener` at bay.
_ = &mut stop_listener => {
break;
}
}
Expand Down Expand Up @@ -666,4 +670,49 @@ mod test {
handle.await.unwrap();
Ok(())
}

#[tokio::test]
async fn test_watch_task_should_be_terminated_when_response_tx_is_closed(
) -> Result<(), Box<dyn std::error::Error>> {
let (req_tx, req_rx) = mpsc::channel(CHANNEL_SIZE);
let (res_tx, res_rx) = mpsc::channel(CHANNEL_SIZE);
let req_stream: ReceiverStream<Result<WatchRequest, tonic::Status>> =
ReceiverStream::new(req_rx);
let header_gen = Arc::new(HeaderGenerator::new(0, 0));
let mut mock_watcher = MockKvWatcherOps::new();
let _ = mock_watcher.expect_watch().times(1).return_const(());
let _ = mock_watcher.expect_cancel().times(1).return_const(());
let watcher = Arc::new(mock_watcher);
let next_id = Arc::new(WatchIdGenerator::new(1));
let handle = tokio::spawn(WatchServer::<DB>::task(
next_id,
Arc::clone(&watcher),
res_tx,
req_stream,
header_gen,
Duration::from_millis(100),
));

req_tx
.send(Ok(WatchRequest {
request_union: Some(RequestUnion::CreateRequest(WatchCreateRequest {
key: "foo".into(),
progress_notify: true,
watch_id: 1,
..Default::default()
})),
}))
.await?;

drop(res_rx);

req_tx
.send(Ok(WatchRequest {
request_union: Some(RequestUnion::ProgressRequest(WatchProgressRequest {})),
}))
.await?;

assert!(timeout(Duration::from_secs(10), handle).await.is_ok());
Ok(())
}
}

0 comments on commit 06c9cc8

Please sign in to comment.