diff --git a/crates/hotshot/src/lib.rs b/crates/hotshot/src/lib.rs
index 090a9687de..461638444a 100644
--- a/crates/hotshot/src/lib.rs
+++ b/crates/hotshot/src/lib.rs
@@ -17,6 +17,7 @@ use hotshot_types::{
     traits::{network::BroadcastDelay, node_implementation::Versions},
 };
 use rand::Rng;
+use tasks::add_health_check_task;
 use url::Url;
 
 /// Contains traits consumed by [`SystemContext`]
@@ -634,6 +635,7 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> SystemContext<
         add_network_tasks::<TYPES, I, V>(&mut handle).await;
         add_consensus_tasks::<TYPES, I, V>(&mut handle).await;
+        add_health_check_task::<TYPES, I, V>(&mut handle).await;
 
         handle
     }
diff --git a/crates/hotshot/src/tasks/mod.rs b/crates/hotshot/src/tasks/mod.rs
index 269c771ff0..0d8428469c 100644
--- a/crates/hotshot/src/tasks/mod.rs
+++ b/crates/hotshot/src/tasks/mod.rs
@@ -8,6 +8,7 @@
 /// Provides trait to create task states from a `SystemContextHandle`
 pub mod task_state;
+use hotshot_task::task::{NetworkHandle, Task};
 use std::{collections::HashSet, sync::Arc, time::Duration};
 
 use async_broadcast::broadcast;
@@ -18,14 +19,14 @@ use futures::{
     future::{BoxFuture, FutureExt},
     stream, StreamExt,
 };
-use hotshot_task::task::Task;
 #[cfg(feature = "rewind")]
 use hotshot_task_impls::rewind::RewindTaskState;
 use hotshot_task_impls::{
     da::DaTaskState,
     events::HotShotEvent,
-    network,
-    network::{NetworkEventTaskState, NetworkMessageTaskState},
+    health_check::HealthCheckTaskState,
+    helpers::broadcast_event,
+    network::{self, NetworkEventTaskState, NetworkMessageTaskState},
     request::NetworkRequestState,
     response::{run_response_task, NetworkResponseState},
     transactions::TransactionTaskState,
@@ -68,14 +69,7 @@ pub async fn add_request_network_task<
 >(
     handle: &mut SystemContextHandle<TYPES, I, V>,
 ) {
-    let state = NetworkRequestState::<TYPES, I>::create_from(handle).await;
-
-    let task = Task::new(
-        state,
-        handle.internal_event_stream.0.clone(),
-        handle.internal_event_stream.1.activate_cloned(),
-    );
-    handle.consensus_registry.run_task(task);
+    handle.add_task(NetworkRequestState::<TYPES, I>::create_from(handle).await);
 }
 
 /// Add a task which responds to requests on the network.
@@ -91,9 +85,12 @@ pub fn add_response_task<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions>(
         handle.private_key().clone(),
         handle.hotshot.id,
     );
+    let task_name = state.get_task_name();
     handle.network_registry.register(run_response_task::<TYPES>(
         state,
+        handle.internal_event_stream.0.clone(),
         handle.internal_event_stream.1.activate_cloned(),
+        handle.generate_task_id(task_name),
     ));
 }
 
@@ -117,9 +114,10 @@ pub fn add_network_message_task<
     let network = Arc::clone(channel);
     let mut state = network_state.clone();
     let shutdown_signal = create_shutdown_event_monitor(handle).fuse();
+    let stream = handle.internal_event_stream.0.clone();
+    let task_id = handle.generate_task_id(network_state.get_task_name());
+    let handle_task_id = task_id.clone();
     let task_handle = async_spawn(async move {
-        futures::pin_mut!(shutdown_signal);
-
         let recv_stream = stream::unfold((), |()| async {
             let msgs = match network.recv_msgs().await {
                 Ok(msgs) => {
@@ -144,9 +142,10 @@ pub fn add_network_message_task<
             Some((msgs, ()))
         });
 
+        let heartbeat_interval =
+            Task::<NetworkMessageTaskState<TYPES>>::get_periodic_interval_in_secs();
         let fused_recv_stream = recv_stream.boxed().fuse();
-        futures::pin_mut!(fused_recv_stream);
-
+        futures::pin_mut!(fused_recv_stream, heartbeat_interval, shutdown_signal);
         loop {
            futures::select! {
                () = shutdown_signal => {
@@ -168,10 +167,16 @@ pub fn add_network_message_task<
                         return;
                     }
                 }
+                _ = Task::<NetworkMessageTaskState<TYPES>>::handle_periodic_delay(&mut heartbeat_interval) => {
+                    broadcast_event(Arc::new(HotShotEvent::HeartBeat(handle_task_id.clone())), &stream).await;
+                }
             }
         }
     });
-    handle.network_registry.register(task_handle);
+    handle.network_registry.register(NetworkHandle {
+        handle: task_handle,
+        task_id,
+    });
 }
 
 /// Add the network task to handle events and send messages.
@@ -194,12 +199,7 @@ pub fn add_network_event_task<
         storage: Arc::clone(&handle.storage()),
         upgrade_lock: handle.hotshot.upgrade_lock.clone(),
     };
-    let task = Task::new(
-        network_state,
-        handle.internal_event_stream.0.clone(),
-        handle.internal_event_stream.1.activate_cloned(),
-    );
-    handle.consensus_registry.run_task(task);
+    handle.add_task(network_state);
 }
 
 /// Adds consensus-related tasks to a `SystemContextHandle`.
@@ -331,6 +331,7 @@ where
         add_consensus_tasks::<TYPES, I, V>(&mut handle).await;
         self.add_network_tasks(&mut handle).await;
+        add_health_check_task(&mut handle).await;
 
         handle
     }
@@ -338,6 +339,7 @@ where
     /// Add byzantine network tasks with the trait
     #[allow(clippy::too_many_lines)]
     async fn add_network_tasks(&'static mut self, handle: &mut SystemContextHandle<TYPES, I, V>) {
+        let task_id = self.get_task_name();
         let state_in = Arc::new(RwLock::new(self));
         let state_out = Arc::clone(&state_in);
         // channels between the task spawned in this function and the network tasks.
@@ -376,8 +378,6 @@ where
         // and broadcast the transformed events to the replacement event stream we just created.
         let shutdown_signal = create_shutdown_event_monitor(handle).fuse();
         let send_handle = async_spawn(async move {
-            futures::pin_mut!(shutdown_signal);
-
             let recv_stream = stream::unfold(original_receiver, |mut recv| async move {
                 match recv.recv().await {
                     Ok(event) => Some((Ok(event), recv)),
@@ -388,7 +388,7 @@ where
                 .boxed();
 
             let fused_recv_stream = recv_stream.fuse();
-            futures::pin_mut!(fused_recv_stream);
+            futures::pin_mut!(fused_recv_stream, shutdown_signal);
 
             loop {
                 futures::select! {
@@ -424,8 +424,6 @@ where
         // and broadcast the transformed events to the original internal event stream
         let shutdown_signal = create_shutdown_event_monitor(handle).fuse();
         let recv_handle = async_spawn(async move {
-            futures::pin_mut!(shutdown_signal);
-
            let network_recv_stream =
                stream::unfold(receiver_from_network, |mut recv| async move {
                    match recv.recv().await {
@@ -436,7 +434,7 @@ where
                });

            let fused_network_recv_stream = network_recv_stream.boxed().fuse();
-            futures::pin_mut!(fused_network_recv_stream);
+            futures::pin_mut!(fused_network_recv_stream, shutdown_signal);

            loop {
                futures::select! {
@@ -467,8 +465,19 @@ where
             }
         });
 
-        handle.network_registry.register(send_handle);
-        handle.network_registry.register(recv_handle);
+        handle.network_registry.register(NetworkHandle {
+            handle: send_handle,
+            task_id: handle.generate_task_id(task_id),
+        });
+        handle.network_registry.register(NetworkHandle {
+            handle: recv_handle,
+            task_id: handle.generate_task_id(task_id),
+        });
+    }
+
+    /// Gets the name of the current task
+    fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<dyn EventTransformerState<TYPES, I, V>>()
     }
 }
 
@@ -659,3 +668,10 @@ pub async fn add_network_tasks<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions>(
         network::vid_filter,
     );
 }
+
+/// Add the health check task
+pub async fn add_health_check_task<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions>(
+    handle: &mut SystemContextHandle<TYPES, I, V>,
+) {
+    handle.add_task(HealthCheckTaskState::<TYPES>::create_from(handle).await);
+}
diff --git a/crates/hotshot/src/tasks/task_state.rs b/crates/hotshot/src/tasks/task_state.rs
index 6cd8e14077..5391283afb 100644
--- a/crates/hotshot/src/tasks/task_state.rs
+++ b/crates/hotshot/src/tasks/task_state.rs
@@ -13,7 +13,7 @@ use async_trait::async_trait;
 use chrono::Utc;
 use hotshot_task_impls::{
     builder::BuilderClient, consensus::ConsensusTaskState, consensus2::Consensus2TaskState,
-    da::DaTaskState, quorum_proposal::QuorumProposalTaskState,
+    da::DaTaskState, health_check::HealthCheckTaskState, quorum_proposal::QuorumProposalTaskState,
     quorum_proposal_recv::QuorumProposalRecvTaskState, quorum_vote::QuorumVoteTaskState,
     request::NetworkRequestState, rewind::RewindTaskState, transactions::TransactionTaskState,
     upgrade::UpgradeTaskState, vid::VidTaskState, view_sync::ViewSyncTaskState,
@@ -383,3 +383,17 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> CreateTaskState
         }
     }
 }
+
+#[async_trait]
+impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> CreateTaskState<TYPES, I, V>
+    for HealthCheckTaskState<TYPES>
+{
+    async fn create_from(handle: &SystemContextHandle<TYPES, I, V>) -> Self {
+        let heartbeat_timeout_duration_in_secs = 30;
+        HealthCheckTaskState::new(
+            handle.hotshot.id,
+            handle.get_task_ids(),
+            heartbeat_timeout_duration_in_secs,
+        )
+    }
+}
diff --git a/crates/hotshot/src/types/handle.rs b/crates/hotshot/src/types/handle.rs
index 0b285b593f..8ba28c6c5e 100644
--- a/crates/hotshot/src/types/handle.rs
+++ b/crates/hotshot/src/types/handle.rs
@@ -22,6 +22,7 @@ use hotshot_types::{
     error::HotShotError,
     traits::{election::Membership, network::ConnectedNetwork, node_implementation::NodeType},
 };
+use rand::Rng;
 #[cfg(async_executor_impl = "tokio")]
 use tokio::task::JoinHandle;
 use tracing::instrument;
@@ -68,15 +69,34 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES> + 'static, V: Versions>
 {
     /// Adds a hotshot consensus-related task to the `SystemContextHandle`.
     pub fn add_task<S: TaskState<Event = HotShotEvent<TYPES>> + 'static>(&mut self, task_state: S) {
+        let task_name = task_state.get_task_name();
         let task = Task::new(
             task_state,
             self.internal_event_stream.0.clone(),
             self.internal_event_stream.1.activate_cloned(),
+            self.generate_task_id(task_name),
         );
 
         self.consensus_registry.run_task(task);
     }
 
+    #[must_use]
+    /// generate a task id for a task
+    pub fn generate_task_id(&self, task_name: &str) -> String {
+        let random = rand::thread_rng().gen_range(0..=9999);
+        let tasks_spawned =
+            self.consensus_registry.task_handles.len() + self.network_registry.handles.len();
+        format!("{task_name}_{tasks_spawned}_{random}")
+    }
+
+    #[must_use]
+    /// Get a list of all the running task ids
+    pub fn get_task_ids(&self) -> Vec<String> {
+        let mut task_ids = self.consensus_registry.get_task_ids();
+        task_ids.extend(self.network_registry.get_task_ids());
+        task_ids
+    }
+
     /// obtains a stream to expose to the user
     pub fn event_stream(&self) -> impl Stream<Item = Event<TYPES>> {
         self.output_event_stream.1.activate_cloned()
diff --git a/crates/task-impls/src/consensus/mod.rs b/crates/task-impls/src/consensus/mod.rs
index 67ae099035..a9cf49aa73 100644
--- a/crates/task-impls/src/consensus/mod.rs
+++ b/crates/task-impls/src/consensus/mod.rs
@@ -773,4 +773,8 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> TaskState
             }
         }
     }
+
+    fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<ConsensusTaskState<TYPES, I, V>>()
+    }
 }
diff --git a/crates/task-impls/src/consensus2/mod.rs b/crates/task-impls/src/consensus2/mod.rs
index 65c80cab34..1e74508278 100644
--- a/crates/task-impls/src/consensus2/mod.rs
+++ b/crates/task-impls/src/consensus2/mod.rs
@@ -172,4 +172,8 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> TaskState
 
     /// Joins all subtasks.
     async fn cancel_subtasks(&mut self) {}
+
+    fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<Consensus2TaskState<TYPES, I, V>>()
+    }
 }
diff --git a/crates/task-impls/src/da.rs b/crates/task-impls/src/da.rs
index 0cf96e4fea..bc38aff941 100644
--- a/crates/task-impls/src/da.rs
+++ b/crates/task-impls/src/da.rs
@@ -379,4 +379,8 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>> TaskState for DaTaskState<TYPES, I>
             }
         }
     }
+
+    fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<DaTaskState<TYPES, I>>()
+    }
 }
diff --git a/crates/task-impls/src/events.rs b/crates/task-impls/src/events.rs
index 082caef45c..bf535af95e 100644
--- a/crates/task-impls/src/events.rs
+++ b/crates/task-impls/src/events.rs
@@ -36,6 +36,10 @@ impl<TYPES: NodeType> TaskEvent for HotShotEvent<TYPES> {
     fn shutdown_event() -> Self {
         HotShotEvent::Shutdown
     }
+
+    fn heartbeat_event(task_id: String) -> Self {
+        HotShotEvent::HeartBeat(task_id)
+    }
 }
 
 /// Wrapper type for the event to notify tasks that a proposal for a view is missing
@@ -216,6 +220,9 @@ pub enum HotShotEvent<TYPES: NodeType> {
     /// 2. The proposal has been correctly signed by the leader of the current view
     /// 3. The justify QC is valid
     QuorumProposalPreliminarilyValidated(Proposal<TYPES, QuorumProposal<TYPES>>),
+
+    /// Periodic heartbeat event for health checking
+    HeartBeat(String),
 }
 
 impl<TYPES: NodeType> Display for HotShotEvent<TYPES> {
@@ -463,6 +470,9 @@ impl<TYPES: NodeType> Display for HotShotEvent<TYPES> {
                     proposal.data.view_number()
                 )
             }
+            HotShotEvent::HeartBeat(task_id) => {
+                write!(f, "HeartBeat(task_id={task_id:?})")
+            }
         }
     }
 }
diff --git a/crates/task-impls/src/harness.rs b/crates/task-impls/src/harness.rs
index 16e8a273b8..5f116dee31 100644
--- a/crates/task-impls/src/harness.rs
+++ b/crates/task-impls/src/harness.rs
@@ -50,7 +50,12 @@ pub async fn run_harness<TYPES, S: TaskState<Event = HotShotEvent<TYPES>> + Send
         allow_extra_output,
     };
 
-    let task = Task::new(state, to_test.clone(), from_test.clone());
+    let task = Task::new(
+        state,
+        to_test.clone(),
+        from_test.clone(),
+        "task_0".to_string(),
+    );
     let handle = task.run();
     let test_future = async move {
diff --git a/crates/task-impls/src/health_check.rs b/crates/task-impls/src/health_check.rs
new file mode 100644
index 0000000000..ee00dbbe7f
--- /dev/null
+++ b/crates/task-impls/src/health_check.rs
@@ -0,0 +1,121 @@
+// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
+// This file is part of the HotShot repository.
+
+// You should have received a copy of the MIT License
+// along with the HotShot repository. If not, see <https://mit-license.org/>.
+
+use std::{
+    collections::{hash_map::Entry, HashMap},
+    marker::PhantomData,
+    sync::Arc,
+    time::Instant,
+};
+
+use anyhow::Result;
+use async_broadcast::{Receiver, Sender};
+use async_lock::Mutex;
+use async_trait::async_trait;
+use hotshot_task::task::TaskState;
+use hotshot_types::traits::node_implementation::NodeType;
+
+use crate::events::{HotShotEvent, HotShotTaskCompleted};
+
+/// Health check task; receives heartbeats from other tasks
+pub struct HealthCheckTaskState<TYPES: NodeType> {
+    /// Node id
+    pub node_id: u64,
+    /// Map of the task id to timestamp of last heartbeat
+    pub task_ids_heartbeat_timestamp: Mutex<HashMap<String, Instant>>,
+    /// How long we wait without a heartbeat before we start logging, in seconds
+    pub heartbeat_timeout_duration_in_secs: u64,
+    /// phantom
+    pub _phantom: PhantomData<TYPES>,
+}
+
+impl<TYPES: NodeType> HealthCheckTaskState<TYPES> {
+    /// Create a new instance of task state with task ids pre populated
+    #[must_use]
+    pub fn new(
+        node_id: u64,
+        task_ids: Vec<String>,
+        heartbeat_timeout_duration_in_secs: u64,
+    ) -> Self {
+        let time = Instant::now();
+        let mut task_ids_heartbeat_timestamp: HashMap<String, Instant> = HashMap::new();
+        for task_id in task_ids {
+            task_ids_heartbeat_timestamp.insert(task_id, time);
+        }
+
+        HealthCheckTaskState {
+            node_id,
+            task_ids_heartbeat_timestamp: Mutex::new(task_ids_heartbeat_timestamp),
+            heartbeat_timeout_duration_in_secs,
+            _phantom: std::marker::PhantomData,
+        }
+    }
+    /// Handles only HeartBeats and updates the timestamp for a task
+    pub async fn handle(
+        &mut self,
+        event: &Arc<HotShotEvent<TYPES>>,
+    ) -> Option<HotShotTaskCompleted> {
+        match event.as_ref() {
+            HotShotEvent::HeartBeat(task_id) => {
+                let mut task_ids_heartbeat_timestamp =
+                    self.task_ids_heartbeat_timestamp.lock().await;
+                match task_ids_heartbeat_timestamp.entry(task_id.clone()) {
+                    Entry::Occupied(mut heartbeat_timestamp) => {
+                        *heartbeat_timestamp.get_mut() = Instant::now();
+                    }
+                    Entry::Vacant(_) => {
+                        // On startup of this task we populate the map with all task ids
+                    }
+                }
+            }
+            HotShotEvent::Shutdown => {
+                return Some(HotShotTaskCompleted);
+            }
+            _ => {}
+        }
+        None
+    }
+}
+
+#[async_trait]
+impl<TYPES: NodeType> TaskState for HealthCheckTaskState<TYPES> {
+    type Event = HotShotEvent<TYPES>;
+
+    async fn handle_event(
+        &mut self,
+        event: Arc<Self::Event>,
+        _sender: &Sender<Arc<Self::Event>>,
+        _receiver: &Receiver<Arc<Self::Event>>,
+    ) -> Result<()> {
+        self.handle(&event).await;
+
+        Ok(())
+    }
+
+    async fn cancel_subtasks(&mut self) {}
+
+    async fn periodic_task(&self, _sender: &Sender<Arc<Self::Event>>, _task_id: String) {
+        let current_time = Instant::now();
+
+        let task_ids_heartbeat = self.task_ids_heartbeat_timestamp.lock().await;
+        for (task_id, heartbeat_timestamp) in task_ids_heartbeat.iter() {
+            if current_time.duration_since(*heartbeat_timestamp).as_secs()
+                > self.heartbeat_timeout_duration_in_secs
+            {
+                tracing::error!(
+                    "Node Id {} has not received a heartbeat for task id {} for {} seconds",
+                    self.node_id,
+                    task_id,
+                    self.heartbeat_timeout_duration_in_secs
+                );
+            }
+        }
+    }
+
+    fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<HealthCheckTaskState<TYPES>>()
+    }
+}
diff --git a/crates/task-impls/src/lib.rs b/crates/task-impls/src/lib.rs
index ed3dc5a0ee..b97685d9e8 100644
--- a/crates/task-impls/src/lib.rs
+++ b/crates/task-impls/src/lib.rs
@@ -64,3 +64,6 @@ pub mod quorum_proposal_recv;
 
 /// Task for storing and replaying all received tasks by a node
 pub mod rewind;
+
+/// Task for listening to HeartBeat events and logging any task that doesn't broadcast after some time
+pub mod health_check;
diff --git a/crates/task-impls/src/network.rs b/crates/task-impls/src/network.rs
index 4f94f37840..52b94e2cf5 100644
--- a/crates/task-impls/src/network.rs
+++ b/crates/task-impls/src/network.rs
@@ -209,6 +209,12 @@ impl<TYPES: NodeType> NetworkMessageTaskState<TYPES> {
                 .await;
         }
     }
+
+    /// Gets the name of the current task
+    #[must_use]
+    pub fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<NetworkMessageTaskState<TYPES>>()
+    }
 }
 
 /// network event task state
@@ -259,6 +265,10 @@ impl<
     }
 
     async fn cancel_subtasks(&mut self) {}
+
+    fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<NetworkEventTaskState<TYPES, V, COMMCHANNEL, S>>()
+    }
 }
 
 impl<
diff --git a/crates/task-impls/src/quorum_proposal/mod.rs b/crates/task-impls/src/quorum_proposal/mod.rs
index cd382c38f5..8fc4c3141e 100644
--- a/crates/task-impls/src/quorum_proposal/mod.rs
+++ b/crates/task-impls/src/quorum_proposal/mod.rs
@@ -531,4 +531,8 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> TaskState
             handle.abort();
         }
     }
+
+    fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<QuorumProposalTaskState<TYPES, I, V>>()
+    }
 }
diff --git a/crates/task-impls/src/quorum_proposal_recv/mod.rs b/crates/task-impls/src/quorum_proposal_recv/mod.rs
index dc0fb5124c..24436cc450 100644
--- a/crates/task-impls/src/quorum_proposal_recv/mod.rs
+++ b/crates/task-impls/src/quorum_proposal_recv/mod.rs
@@ -176,4 +176,8 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> TaskState
             }
         }
     }
+
+    fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<QuorumProposalRecvTaskState<TYPES, I, V>>()
+    }
 }
diff --git a/crates/task-impls/src/quorum_vote/mod.rs b/crates/task-impls/src/quorum_vote/mod.rs
index 4ddd531cab..969d99ba26 100644
--- a/crates/task-impls/src/quorum_vote/mod.rs
+++ b/crates/task-impls/src/quorum_vote/mod.rs
@@ -675,4 +675,8 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> TaskState
             handle.abort();
         }
     }
+
+    fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<QuorumVoteTaskState<TYPES, I, V>>()
+    }
 }
diff --git a/crates/task-impls/src/request.rs b/crates/task-impls/src/request.rs
index ef0faf0ed5..d587a927a7 100644
--- a/crates/task-impls/src/request.rs
+++ b/crates/task-impls/src/request.rs
@@ -146,6 +146,10 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>> TaskState for NetworkRequestState<TYPES, I>
         }
     }
+
+    fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<NetworkRequestState<TYPES, I>>()
+    }
 }
 
 impl<TYPES: NodeType, I: NodeImplementation<TYPES>> NetworkRequestState<TYPES, I> {
diff --git a/crates/task-impls/src/response.rs b/crates/task-impls/src/response.rs
index e5d3563089..7d0329b140 100644
--- a/crates/task-impls/src/response.rs
+++ b/crates/task-impls/src/response.rs
@@ -6,12 +6,13 @@
 
 use std::{sync::Arc, time::Duration};
 
-use async_broadcast::Receiver;
+use async_broadcast::{Receiver, Sender};
 use async_compatibility_layer::art::{async_sleep, async_spawn};
-#[cfg(async_executor_impl = "async-std")]
-use async_std::task::JoinHandle;
 use futures::{FutureExt, StreamExt};
-use hotshot_task::dependency::{Dependency, EventDependency};
+use hotshot_task::{
+    dependency::{Dependency, EventDependency},
+    task::{NetworkHandle, Task},
+};
 use hotshot_types::{
     consensus::{Consensus, LockedConsensusState, OuterConsensus},
     data::VidDisperseShare,
@@ -28,11 +29,9 @@ use hotshot_types::{
     },
 };
 use sha2::{Digest, Sha256};
-#[cfg(async_executor_impl = "tokio")]
-use tokio::task::JoinHandle;
 use tracing::instrument;
 
-use crate::events::HotShotEvent;
+use crate::{events::HotShotEvent, health_check::HealthCheckTaskState, helpers::broadcast_event};
 
 /// Time to wait for txns before sending `ResponseMessage::NotFound`
 const TXNS_TIMEOUT: Duration = Duration::from_millis(100);
@@ -76,8 +75,16 @@ impl<TYPES: NodeType> NetworkResponseState<TYPES> {
     /// Run the request response loop until a `HotShotEvent::Shutdown` is received.
     /// Or the stream is closed.
-    async fn run_loop(mut self, shutdown: EventDependency<Arc<HotShotEvent<TYPES>>>) {
+    async fn run_loop(
+        mut self,
+        shutdown: EventDependency<Arc<HotShotEvent<TYPES>>>,
+        sender: Sender<Arc<HotShotEvent<TYPES>>>,
+        task_name: String,
+    ) {
         let mut shutdown = Box::pin(shutdown.completed().fuse());
+        let heartbeat_interval =
+            Task::<HealthCheckTaskState<TYPES>>::get_periodic_interval_in_secs();
+        futures::pin_mut!(heartbeat_interval);
         loop {
             futures::select! {
                 req = self.receiver.next() => {
@@ -86,6 +93,9 @@ impl<TYPES: NodeType> NetworkResponseState<TYPES> {
                         None => return,
                     }
                 },
+                _ = Task::<HealthCheckTaskState<TYPES>>::handle_periodic_delay(&mut heartbeat_interval) => {
+                    broadcast_event(Arc::new(HotShotEvent::HeartBeat(task_name.clone())), &sender).await;
+                },
                 _ = shutdown => {
                     return;
                 }
@@ -231,6 +241,11 @@ impl<TYPES: NodeType> NetworkResponseState<TYPES> {
             None => ResponseMessage::NotFound,
         }
     }
+
+    /// Get the task name
+    pub fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<NetworkResponseState<TYPES>>()
+    }
 }
 
 /// Check the signature
@@ -249,11 +264,14 @@ fn valid_signature<TYPES: NodeType>(
 /// on the `event_stream` arg.
 pub fn run_response_task<TYPES: NodeType>(
     task_state: NetworkResponseState<TYPES>,
-    event_stream: Receiver<Arc<HotShotEvent<TYPES>>>,
-) -> JoinHandle<()> {
-    let dep = EventDependency::new(
-        event_stream,
+    sender: Sender<Arc<HotShotEvent<TYPES>>>,
+    receiver: Receiver<Arc<HotShotEvent<TYPES>>>,
+    task_id: String,
+) -> NetworkHandle {
+    let shutdown = EventDependency::new(
+        receiver,
         Box::new(|e| matches!(e.as_ref(), HotShotEvent::Shutdown)),
     );
-    async_spawn(task_state.run_loop(dep))
+    let handle = async_spawn(task_state.run_loop(shutdown, sender, task_id.clone()));
+    NetworkHandle { handle, task_id }
 }
diff --git a/crates/task-impls/src/rewind.rs b/crates/task-impls/src/rewind.rs
index 669b410b52..33907f71c8 100644
--- a/crates/task-impls/src/rewind.rs
+++ b/crates/task-impls/src/rewind.rs
@@ -70,4 +70,8 @@ impl<TYPES: NodeType> TaskState for RewindTaskState<TYPES> {
             }
         }
     }
+
+    fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<RewindTaskState<TYPES>>()
+    }
 }
diff --git a/crates/task-impls/src/transactions.rs b/crates/task-impls/src/transactions.rs
index fc10b549ee..e6238d74a7 100644
--- a/crates/task-impls/src/transactions.rs
+++ b/crates/task-impls/src/transactions.rs
@@ -792,4 +792,8 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> TaskState
     }
 
     async fn cancel_subtasks(&mut self) {}
+
+    fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<TransactionTaskState<TYPES, I, V>>()
+    }
 }
diff --git a/crates/task-impls/src/upgrade.rs b/crates/task-impls/src/upgrade.rs
index 680c61d82c..1a05629879 100644
--- a/crates/task-impls/src/upgrade.rs
+++ b/crates/task-impls/src/upgrade.rs
@@ -358,4 +358,8 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> TaskState
     }
 
     async fn cancel_subtasks(&mut self) {}
+
+    fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<UpgradeTaskState<TYPES, I, V>>()
+    }
 }
diff --git a/crates/task-impls/src/vid.rs b/crates/task-impls/src/vid.rs
index 3243f356ae..54e40ef177 100644
--- a/crates/task-impls/src/vid.rs
+++ b/crates/task-impls/src/vid.rs
@@ -171,4 +171,8 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>> TaskState for VidTaskState<TYPES, I>
     }
 
     async fn cancel_subtasks(&mut self) {}
+
+    fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<VidTaskState<TYPES, I>>()
+    }
 }
diff --git a/crates/task-impls/src/view_sync.rs b/crates/task-impls/src/view_sync.rs
index e2c4e5cb95..2e9544a5c6 100644
--- a/crates/task-impls/src/view_sync.rs
+++ b/crates/task-impls/src/view_sync.rs
@@ -120,6 +120,10 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>> TaskState for ViewSyncTaskState<TYPES, I>
     }
 
     async fn cancel_subtasks(&mut self) {}
+
+    fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<ViewSyncTaskState<TYPES, I>>()
+    }
 }
 
 /// State of a view sync replica task
@@ -169,6 +173,10 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>> TaskState
     }
 
     async fn cancel_subtasks(&mut self) {}
+
+    fn get_task_name(&self) -> &'static str {
+        std::any::type_name::<ViewSyncReplicaTaskState<TYPES, I>>()
+    }
 }
 
 impl<TYPES: NodeType, I: NodeImplementation<TYPES>> ViewSyncTaskState<TYPES, I> {
diff --git a/crates/task/src/task.rs b/crates/task/src/task.rs
index af195d1e27..4ec0986c91 100644
--- a/crates/task/src/task.rs
+++ b/crates/task/src/task.rs
@@ -4,7 +4,7 @@
 // You should have received a copy of the MIT License
 // along with the HotShot repository. If not, see <https://mit-license.org/>.
 
-use std::sync::Arc;
+use std::{sync::Arc, time::Duration};
 
 use anyhow::Result;
 use async_broadcast::{Receiver, Sender};
@@ -16,6 +16,9 @@ use futures::future::join_all;
 #[cfg(async_executor_impl = "tokio")]
 use futures::future::try_join_all;
 #[cfg(async_executor_impl = "tokio")]
+use futures::FutureExt;
+use futures::StreamExt;
+#[cfg(async_executor_impl = "tokio")]
 use tokio::task::{spawn, JoinHandle};
 
 /// Trait for events that long-running tasks handle
@@ -25,11 +28,14 @@ pub trait TaskEvent: PartialEq {
     /// Note that this is necessarily uniform across all tasks.
     /// Exiting the task loop is handled by the task spawner, rather than the task individually.
     fn shutdown_event() -> Self;
+
+    /// The heartbeat event
+    fn heartbeat_event(task_id: String) -> Self;
 }
 
 #[async_trait]
 /// Type for mutable task state that can be used as the state for a `Task`
-pub trait TaskState: Send {
+pub trait TaskState: Send + Sync {
     /// Type of event sent and received by the task
     type Event: TaskEvent + Clone + Send + Sync;
 
@@ -43,6 +49,27 @@ pub trait TaskState: Send {
         _sender: &Sender<Arc<Self::Event>>,
         _receiver: &Receiver<Arc<Self::Event>>,
     ) -> Result<()>;
+
+    /// Runs a specified job in the main task every `Task::PERIODIC_INTERVAL_IN_SECS` seconds
+    async fn periodic_task(&self, sender: &Sender<Arc<Self::Event>>, task_id: String) {
+        match sender
+            .broadcast_direct(Arc::new(Self::Event::heartbeat_event(task_id)))
+            .await
+        {
+            Ok(None) => (),
+            Ok(Some(_overflowed)) => {
+                tracing::error!(
+                    "Event sender queue overflow, oldest event removed from queue: Heartbeat Event"
+                );
+            }
+            Err(async_broadcast::SendError(_e)) => {
+                tracing::warn!("Event: Heartbeat\n Sending failed, event stream probably shutdown");
+            }
+        }
+    }
+
+    /// Gets the name of the current task
+    fn get_task_name(&self) -> &'static str;
 }
 
 /// A basic task which loops waiting for events to come from `event_receiver`
@@ -58,15 +85,26 @@ pub struct Task<S: TaskState> {
     sender: Sender<Arc<S::Event>>,
     /// Receives events that are broadcast from any task, including itself
     receiver: Receiver<Arc<S::Event>>,
+    /// The generated task id
+    task_id: String,
 }
 
 impl<S: TaskState + Send + 'static> Task<S> {
+    /// Constant for how often we run our periodic tasks, such as broadcasting a heartbeat
+    const PERIODIC_INTERVAL_IN_SECS: u64 = 10;
+
     /// Create a new task
-    pub fn new(state: S, sender: Sender<Arc<S::Event>>, receiver: Receiver<Arc<S::Event>>) -> Self {
+    pub fn new(
+        state: S,
+        sender: Sender<Arc<S::Event>>,
+        receiver: Receiver<Arc<S::Event>>,
+        task_id: String,
+    ) -> Self {
         Task {
             state,
             sender,
             receiver,
+            task_id,
         }
     }
 
@@ -75,38 +113,100 @@ impl<S: TaskState + Send + 'static> Task<S> {
         Box::new(self.state) as Box<dyn TaskState<Event = S::Event>>
     }
 
+    #[cfg(async_executor_impl = "async-std")]
+    /// Periodic delay
+    pub fn get_periodic_interval_in_secs() -> futures::stream::Fuse<async_std::stream::Interval> {
+        async_std::stream::interval(Duration::from_secs(Self::PERIODIC_INTERVAL_IN_SECS)).fuse()
+    }
+
+    #[cfg(async_executor_impl = "async-std")]
+    /// Handle periodic delay interval
+    pub fn handle_periodic_delay(
+        periodic_interval: &mut futures::stream::Fuse<async_std::stream::Interval>,
+    ) -> futures::stream::Next<'_, futures::stream::Fuse<async_std::stream::Interval>> {
+        periodic_interval.next()
+    }
+
+    #[cfg(async_executor_impl = "tokio")]
+    #[must_use]
+    /// Periodic delay
+    pub fn get_periodic_interval_in_secs() -> tokio::time::Interval {
+        tokio::time::interval(Duration::from_secs(Self::PERIODIC_INTERVAL_IN_SECS))
+    }
+
+    #[cfg(async_executor_impl = "tokio")]
+    /// Handle periodic delay interval
+    pub fn handle_periodic_delay(
+        periodic_interval: &mut tokio::time::Interval,
+    ) -> futures::future::Fuse<impl futures::Future<Output = tokio::time::Instant> + '_> {
+        periodic_interval.tick().fuse()
+    }
+
     /// Spawn the task loop, consuming self. Will continue until
     /// the task reaches some shutdown condition
-    pub fn run(mut self) -> JoinHandle<Box<dyn TaskState<Event = S::Event>>> {
-        spawn(async move {
-            loop {
-                match self.receiver.recv_direct().await {
-                    Ok(input) => {
-                        if *input == S::Event::shutdown_event() {
-                            self.state.cancel_subtasks().await;
+    pub fn run(mut self) -> HotShotTaskHandle<S::Event> {
+        let task_id = self.task_id.clone();
+        let handle = spawn(async move {
+            let recv_stream =
+                futures::stream::unfold(self.receiver.clone(), |mut recv| async move {
+                    match recv.recv_direct().await {
+                        Ok(event) => Some((Ok(event), recv)),
+                        Err(e) => Some((Err(e), recv)),
+                    }
+                })
+                .boxed();
 
-                            break self.boxed_state();
-                        }
+            let fused_recv_stream = recv_stream.fuse();
+            let periodic_interval = Self::get_periodic_interval_in_secs();
+            futures::pin_mut!(periodic_interval, fused_recv_stream);
+            loop {
+                futures::select! {
+                    input = fused_recv_stream.next() => {
+                        match input {
+                            Some(Ok(input)) => {
+                                if *input == S::Event::shutdown_event() {
+                                    self.state.cancel_subtasks().await;
 
-                        let _ =
-                            S::handle_event(&mut self.state, input, &self.sender, &self.receiver)
+                                    break self.boxed_state();
+                                }
+                                let _ = S::handle_event(
+                                    &mut self.state,
+                                    input,
+                                    &self.sender,
+                                    &self.receiver,
+                                )
                                 .await
                                 .inspect_err(|e| tracing::info!("{e}"));
+                            }
+                            Some(Err(e)) => {
+                                tracing::error!("Failed to receive from event stream Error: {}", e);
+                            }
+                            None => {}
+                        }
                     }
-                    Err(e) => {
-                        tracing::error!("Failed to receive from event stream Error: {}", e);
-                    }
+                    _ = Self::handle_periodic_delay(&mut periodic_interval) => {
+                        self.state.periodic_task(&self.sender, self.task_id.clone()).await;
+                    },
                 }
             }
-        })
+        });
+        HotShotTaskHandle { handle, task_id }
     }
 }
 
+/// Wrapper around handle and task id so we can map
+pub struct HotShotTaskHandle<E: Send + Sync + Clone + TaskEvent> {
+    /// Handle for the task
+    pub handle: JoinHandle<Box<dyn TaskState<Event = E>>>,
+    /// Generated task id
+    pub task_id: String,
+}
+
 #[derive(Default)]
 /// A collection of tasks which can handle shutdown
 pub struct ConsensusTaskRegistry<E: Send + Sync + Clone + TaskEvent> {
     /// Tasks this registry controls
-    task_handles: Vec<JoinHandle<Box<dyn TaskState<Event = E>>>>,
+    pub task_handles: Vec<HotShotTaskHandle<E>>,
 }
 
 impl<E: Send + Sync + Clone + TaskEvent> ConsensusTaskRegistry<E> {
@@ -117,10 +217,21 @@ impl<E: Send + Sync + Clone + TaskEvent> ConsensusTaskRegistry<E> {
             task_handles: vec![],
         }
     }
+
     /// Add a task to the registry
-    pub fn register(&mut self, handle: JoinHandle<Box<dyn TaskState<Event = E>>>) {
+    pub fn register(&mut self, handle: HotShotTaskHandle<E>) {
         self.task_handles.push(handle);
     }
+
+    #[must_use]
+    /// Get all task ids from registry
+    pub fn get_task_ids(&self) -> Vec<String> {
+        self.task_handles
+            .iter()
+            .map(|wrapped_handle| wrapped_handle.task_id.clone())
+            .collect()
+    }
+
     /// Try to cancel/abort the task this registry has
     ///
     /// # Panics
@@ -129,11 +240,11 @@ impl<E: Send + Sync + Clone + TaskEvent> ConsensusTaskRegistry<E> {
     pub async fn shutdown(&mut self) {
         let handles = &mut self.task_handles;
 
-        while let Some(handle) = handles.pop() {
+        while let Some(wrapped_handle) = handles.pop() {
             #[cfg(async_executor_impl = "async-std")]
-            let mut task_state = handle.await;
+            let mut task_state = wrapped_handle.handle.await;
             #[cfg(async_executor_impl = "tokio")]
-            let mut task_state = handle.await.unwrap();
+            let mut task_state = wrapped_handle.handle.await.unwrap();
 
             task_state.cancel_subtasks().await;
         }
@@ -150,20 +261,33 @@ impl<E: Send + Sync + Clone + TaskEvent> ConsensusTaskRegistry<E> {
     /// # Panics
     /// Panics if one of the tasks panicked
     pub async fn join_all(self) -> Vec<Box<dyn TaskState<Event = E>>> {
+        let handles: Vec<JoinHandle<Box<dyn TaskState<Event = E>>>> = self
+            .task_handles
+            .into_iter()
+            .map(|wrapped| wrapped.handle)
+            .collect();
         #[cfg(async_executor_impl = "async-std")]
-        let states = join_all(self.task_handles).await;
+        let states = join_all(handles).await;
         #[cfg(async_executor_impl = "tokio")]
-        let states = try_join_all(self.task_handles).await.unwrap();
+        let states = try_join_all(handles).await.unwrap();
 
         states
     }
 }
 
+/// Wrapper around join handle and task id for network tasks
+pub struct NetworkHandle {
+    /// Task handle
+    pub handle: JoinHandle<()>,
+    /// Generated task id
+    pub task_id: String,
+}
+
 #[derive(Default)]
 /// A collection of tasks which can handle shutdown
 pub struct NetworkTaskRegistry {
     /// Tasks this registry controls
-    pub handles: Vec<JoinHandle<()>>,
+    pub handles: Vec<NetworkHandle>,
 }
 
 impl NetworkTaskRegistry {
@@ -173,6 +297,15 @@ impl NetworkTaskRegistry {
         NetworkTaskRegistry { handles: vec![] }
     }
 
+    #[must_use]
+    /// Get all task ids from registry
+    pub fn get_task_ids(&self) -> Vec<String> {
+        self.handles
+            .iter()
+            .map(|wrapped_handle| wrapped_handle.task_id.clone())
+            .collect()
+    }
+
     #[allow(clippy::unused_async)]
     /// Shuts down all tasks managed by this instance.
     ///
@@ -184,16 +317,18 @@ impl NetworkTaskRegistry {
     /// tasks being joined return an error.
     pub async fn shutdown(&mut self) {
         let handles = std::mem::take(&mut self.handles);
+        let task_handles: Vec<JoinHandle<()>> =
+            handles.into_iter().map(|wrapped| wrapped.handle).collect();
 
         #[cfg(async_executor_impl = "async-std")]
-        join_all(handles).await;
+        join_all(task_handles).await;
         #[cfg(async_executor_impl = "tokio")]
-        try_join_all(handles)
+        try_join_all(task_handles)
             .await
             .expect("Failed to join all tasks during shutdown");
     }
 
     /// Add a task to the registry
-    pub fn register(&mut self, handle: JoinHandle<()>) {
+    pub fn register(&mut self, handle: NetworkHandle) {
         self.handles.push(handle);
     }
 }
diff --git a/crates/testing/tests/tests_1/network_task.rs b/crates/testing/tests/tests_1/network_task.rs
index a3b3245533..1631d0be19 100644
--- a/crates/testing/tests/tests_1/network_task.rs
+++ b/crates/testing/tests/tests_1/network_task.rs
@@ -4,13 +4,12 @@
 // You should have received a copy of the MIT License
 // along with the HotShot repository. If not, see <https://mit-license.org/>.
-use std::{sync::Arc, time::Duration};
-
 use async_broadcast::Sender;
 use async_compatibility_layer::art::async_timeout;
 use async_lock::RwLock;
 use hotshot::traits::implementations::MemoryNetwork;
 use hotshot_example_types::node_types::{MemoryImpl, TestTypes, TestVersions};
+use hotshot_task::task::TaskState;
 use hotshot_task::task::{ConsensusTaskRegistry, Task};
 use hotshot_task_impls::{
     events::HotShotEvent,
@@ -28,6 +27,7 @@ use hotshot_types::{
         node_implementation::{ConsensusTime, NodeType},
     },
 };
+use std::{sync::Arc, time::Duration};
 
 // Test that the event task sends a message, and the message task receives it
 // and emits the proper event
@@ -74,7 +74,8 @@ async fn test_network_task() {
     let (tx, rx) = async_broadcast::broadcast(10);
     let mut task_reg = ConsensusTaskRegistry::new();
 
-    let task = Task::new(network_state, tx.clone(), rx);
+    let task_name = network_state.get_task_name();
+    let task = Task::new(network_state, tx.clone(), rx, task_name.to_string());
     task_reg.run_task(task);
 
     let mut generator = TestViewGenerator::generate(membership.clone(), membership);
@@ -150,7 +151,12 @@ async fn test_network_storage_fail() {
     let (tx, rx) = async_broadcast::broadcast(10);
     let mut task_reg = ConsensusTaskRegistry::new();
 
-    let task = Task::new(network_state, tx.clone(), rx);
+    let task = Task::new(
+        network_state,
+        tx.clone(),
+        rx,
+        "NetworkEventTaskState_0".to_string(),
+    );
     task_reg.run_task(task);
 
     let mut generator = TestViewGenerator::generate(membership.clone(), membership);
diff --git a/crates/testing/tests/tests_1/test_with_failures_2.rs b/crates/testing/tests/tests_1/test_with_failures_2.rs
index 54d3e5193f..9d551629b8 100644
--- a/crates/testing/tests/tests_1/test_with_failures_2.rs
+++ b/crates/testing/tests/tests_1/test_with_failures_2.rs
@@ -24,7 +24,6 @@ use hotshot_testing::{
     view_sync_task::ViewSyncTaskDescription,
 };
 use hotshot_types::{data::ViewNumber, traits::node_implementation::ConsensusTime};
-#[cfg(async_executor_impl = "async-std")]
 use {hotshot::tasks::DishonestLeader, hotshot_testing::test_builder::Behaviour, std::rc::Rc};
 
 // Test that a good leader can succeed in the view directly after view sync
 cross_tests!(
@@ -68,7 +67,6 @@ cross_tests!(
     }
 );
 
-#[cfg(async_executor_impl = "async-std")]
 cross_tests!(
     TestName: dishonest_leader,
     Impls: [MemoryImpl],
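
Reviewer note: the moving pieces above span several crates, so here is a minimal, self-contained sketch of the heartbeat/health-check pattern this PR introduces. It is not HotShot code; the event type, channel, task names, and intervals are simplified assumptions (plain `tokio` instead of the `async-std`/`tokio` abstraction, `tokio::sync::broadcast` instead of `async_broadcast`, and short durations so the demo finishes quickly).

```rust
// Sketch only: one worker broadcasts HeartBeat(task_id) on a fixed interval
// (the role of the default `TaskState::periodic_task` in the diff), and a
// health-check loop records the latest timestamp per task id and logs any
// task silent past a timeout (the role of `HealthCheckTaskState`).
use std::collections::HashMap;
use std::time::{Duration, Instant};

use tokio::sync::broadcast;
use tokio::time::interval;

#[derive(Clone, Debug)]
enum Event {
    HeartBeat(String), // task id, e.g. "NetworkMessageTaskState_0_1234"
    Shutdown,
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = broadcast::channel::<Event>(16);

    // Worker: emit a heartbeat every second (10 seconds in the PR).
    let beat_tx = tx.clone();
    tokio::spawn(async move {
        let mut tick = interval(Duration::from_secs(1));
        loop {
            tick.tick().await;
            if beat_tx.send(Event::HeartBeat("worker_0".into())).is_err() {
                return; // all receivers dropped; stream shut down
            }
        }
    });

    // Health checker: update timestamps on heartbeats, scan periodically.
    let timeout = Duration::from_secs(3); // 30 seconds in the PR
    let mut last_seen: HashMap<String, Instant> = HashMap::new();
    let mut scan = interval(Duration::from_secs(1));
    let start = Instant::now();
    loop {
        tokio::select! {
            msg = rx.recv() => match msg {
                Ok(Event::HeartBeat(id)) => { last_seen.insert(id, Instant::now()); }
                Ok(Event::Shutdown) | Err(_) => break,
            },
            _ = scan.tick() => {
                for (id, seen) in &last_seen {
                    if seen.elapsed() > timeout {
                        eprintln!("no heartbeat from {id} for {:?}", seen.elapsed());
                    }
                }
                if start.elapsed() > Duration::from_secs(5) {
                    let _ = tx.send(Event::Shutdown); // end the demo
                }
            }
        }
    }
}
```

One design point worth flagging in review: because each task emits its heartbeat from the same `select!` loop that drives `handle_event`, a task wedged inside `handle_event` stops producing heartbeats, which is exactly what makes the stall observable to the health check task.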