diff --git a/host/src/lib.rs b/host/src/lib.rs index a4df64dc..e3d14b37 100644 --- a/host/src/lib.rs +++ b/host/src/lib.rs @@ -152,6 +152,7 @@ pub struct ProverState { pub enum Message { Cancel(TaskDescriptor), Task(ProofRequest), + TaskComplete(ProofRequest), } impl From<&ProofRequest> for Message { @@ -192,9 +193,9 @@ impl ProverState { let opts_clone = opts.clone(); let chain_specs_clone = chain_specs.clone(); - + let sender = task_channel.clone(); tokio::spawn(async move { - ProofActor::new(receiver, opts_clone, chain_specs_clone) + ProofActor::new(sender, receiver, opts_clone, chain_specs_clone) .run() .await; }); diff --git a/host/src/proof.rs b/host/src/proof.rs index 31a56e72..e48cb07b 100644 --- a/host/src/proof.rs +++ b/host/src/proof.rs @@ -1,4 +1,7 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{HashMap, VecDeque}, + sync::Arc, +}; use raiko_core::{ interfaces::{ProofRequest, RaikoError}, @@ -13,10 +16,13 @@ use raiko_lib::{ use raiko_tasks::{get_task_manager, TaskDescriptor, TaskManager, TaskManagerWrapper, TaskStatus}; use tokio::{ select, - sync::{mpsc::Receiver, Mutex, OwnedSemaphorePermit, Semaphore}, + sync::{ + mpsc::{Receiver, Sender}, + Mutex, + }, }; use tokio_util::sync::CancellationToken; -use tracing::{error, info, warn}; +use tracing::{debug, error, info, warn}; use crate::{ cache, @@ -32,26 +38,36 @@ use crate::{ pub struct ProofActor { opts: Opts, chain_specs: SupportedChainSpecs, - tasks: Arc>>, + running_tasks: Arc>>, + pending_tasks: Arc>>, receiver: Receiver, + sender: Sender, } impl ProofActor { - pub fn new(receiver: Receiver, opts: Opts, chain_specs: SupportedChainSpecs) -> Self { - let tasks = Arc::new(Mutex::new( + pub fn new( + sender: Sender, + receiver: Receiver, + opts: Opts, + chain_specs: SupportedChainSpecs, + ) -> Self { + let running_tasks = Arc::new(Mutex::new( HashMap::::new(), )); + let pending_tasks = Arc::new(Mutex::new(VecDeque::::new())); Self { - tasks, opts, chain_specs, + running_tasks, + pending_tasks, receiver, + sender, } } pub async fn cancel_task(&mut self, key: TaskDescriptor) -> HostResult<()> { - let tasks_map = self.tasks.lock().await; + let tasks_map = self.running_tasks.lock().await; let Some(task) = tasks_map.get(&key) else { warn!("No task with those keys to cancel"); return Ok(()); @@ -76,7 +92,7 @@ impl ProofActor { Ok(()) } - pub async fn run_task(&mut self, proof_request: ProofRequest, _permit: OwnedSemaphorePermit) { + pub async fn run_task(&mut self, proof_request: ProofRequest) { let cancel_token = CancellationToken::new(); let Ok((chain_id, blockhash)) = get_task_data( @@ -97,10 +113,11 @@ impl ProofActor { proof_request.prover.clone().to_string(), )); - let mut tasks = self.tasks.lock().await; + let mut tasks = self.running_tasks.lock().await; tasks.insert(key.clone(), cancel_token.clone()); + let sender = self.sender.clone(); - let tasks = self.tasks.clone(); + let tasks = self.running_tasks.clone(); let opts = self.opts.clone(); let chain_specs = self.chain_specs.clone(); @@ -109,7 +126,7 @@ impl ProofActor { _ = cancel_token.cancelled() => { info!("Task cancelled"); } - result = Self::handle_message(proof_request, key.clone(), &opts, &chain_specs) => { + result = Self::handle_message(proof_request.clone(), key.clone(), &opts, &chain_specs) => { match result { Ok(status) => { info!("Host handling message: {status:?}"); @@ -122,25 +139,56 @@ impl ProofActor { } let mut tasks = tasks.lock().await; tasks.remove(&key); + // notify complete task to let next pending task run + sender + .send(Message::TaskComplete(proof_request)) + .await + .expect("Couldn't send message"); }); } pub async fn run(&mut self) { - let semaphore = Arc::new(Semaphore::new(self.opts.concurrency_limit)); - + // recv() is protected by outside mpsc, no lock needed here while let Some(message) = self.receiver.recv().await { match message { Message::Cancel(key) => { + debug!("Message::Cancel task: {:?}", key); if let Err(error) = self.cancel_task(key).await { error!("Failed to cancel task: {error}") } } Message::Task(proof_request) => { - let permit = Arc::clone(&semaphore) - .acquire_owned() - .await - .expect("Couldn't acquire permit"); - self.run_task(proof_request, permit).await; + debug!("Message::Task proof_request: {:?}", proof_request); + let running_task_count = self.running_tasks.lock().await.len(); + if running_task_count < self.opts.concurrency_limit { + info!("Running task {:?}", proof_request); + self.run_task(proof_request).await; + } else { + info!( + "Task concurrency limit reached, current running {:?}, pending: {:?}", + running_task_count, + self.pending_tasks.lock().await.len() + ); + let mut pending_tasks = self.pending_tasks.lock().await; + pending_tasks.push_back(proof_request); + } + } + Message::TaskComplete(req) => { + // pop up pending task if any task complete + debug!("Message::TaskComplete: {:?}", req); + info!( + "task completed, current running {:?}, pending: {:?}", + self.running_tasks.lock().await.len(), + self.pending_tasks.lock().await.len() + ); + let mut pending_tasks = self.pending_tasks.lock().await; + if let Some(proof_request) = pending_tasks.pop_front() { + info!("Pop out pending task {:?}", proof_request); + self.sender + .send(Message::Task(proof_request)) + .await + .expect("Couldn't send message"); + } } } } @@ -190,7 +238,7 @@ pub async fn handle_proof( store: Option<&mut TaskManagerWrapper>, ) -> HostResult { info!( - "# Generating proof for block {} on {}", + "Generating proof for block {} on {}", proof_request.block_number, proof_request.network );