Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(raiko): put the tasks that cannot run in parallel into pending list #358

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions host/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ pub struct ProverState {
pub enum Message {
Cancel(TaskDescriptor),
Task(ProofRequest),
TaskComplete(ProofRequest),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could use a oneshot channel to notify the watcher of each task's result.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean notifying completion over a one-shot channel and letting it send the next task in?

}

impl From<&ProofRequest> for Message {
Expand Down Expand Up @@ -184,9 +185,9 @@ impl ProverState {

let opts_clone = opts.clone();
let chain_specs_clone = chain_specs.clone();

let sender = task_channel.clone();
tokio::spawn(async move {
ProofActor::new(receiver, opts_clone, chain_specs_clone)
ProofActor::new(sender, receiver, opts_clone, chain_specs_clone)
.run()
.await;
});
Expand Down
88 changes: 68 additions & 20 deletions host/src/proof.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::{collections::HashMap, sync::Arc};
use std::{
collections::{HashMap, VecDeque},
sync::Arc,
};

use raiko_core::{
interfaces::{ProofRequest, RaikoError},
Expand All @@ -13,10 +16,13 @@ use raiko_lib::{
use raiko_tasks::{get_task_manager, TaskDescriptor, TaskManager, TaskManagerWrapper, TaskStatus};
use tokio::{
select,
sync::{mpsc::Receiver, Mutex, OwnedSemaphorePermit, Semaphore},
sync::{
mpsc::{Receiver, Sender},
Mutex,
},
};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use tracing::{debug, error, info, warn};

use crate::{
cache,
Expand All @@ -32,26 +38,36 @@ use crate::{
pub struct ProofActor {
opts: Opts,
chain_specs: SupportedChainSpecs,
tasks: Arc<Mutex<HashMap<TaskDescriptor, CancellationToken>>>,
running_tasks: Arc<Mutex<HashMap<TaskDescriptor, CancellationToken>>>,
pending_tasks: Arc<Mutex<VecDeque<ProofRequest>>>,
receiver: Receiver<Message>,
sender: Sender<Message>,
}

impl ProofActor {
pub fn new(receiver: Receiver<Message>, opts: Opts, chain_specs: SupportedChainSpecs) -> Self {
let tasks = Arc::new(Mutex::new(
pub fn new(
sender: Sender<Message>,
receiver: Receiver<Message>,
opts: Opts,
chain_specs: SupportedChainSpecs,
) -> Self {
let running_tasks = Arc::new(Mutex::new(
HashMap::<TaskDescriptor, CancellationToken>::new(),
));
let pending_tasks = Arc::new(Mutex::new(VecDeque::<ProofRequest>::new()));

Self {
tasks,
opts,
chain_specs,
running_tasks,
pending_tasks,
receiver,
sender,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Rust, we usually use `rx` and `tx` rather than `receiver` and `sender`.

}
}

pub async fn cancel_task(&mut self, key: TaskDescriptor) -> HostResult<()> {
let tasks_map = self.tasks.lock().await;
let tasks_map = self.running_tasks.lock().await;
let Some(task) = tasks_map.get(&key) else {
warn!("No task with those keys to cancel");
return Ok(());
Expand All @@ -76,7 +92,7 @@ impl ProofActor {
Ok(())
}

pub async fn run_task(&mut self, proof_request: ProofRequest, _permit: OwnedSemaphorePermit) {
pub async fn run_task(&mut self, proof_request: ProofRequest) {
let cancel_token = CancellationToken::new();

let Ok((chain_id, blockhash)) = get_task_data(
Expand All @@ -97,10 +113,11 @@ impl ProofActor {
proof_request.prover.clone().to_string(),
));

let mut tasks = self.tasks.lock().await;
let mut tasks = self.running_tasks.lock().await;
tasks.insert(key.clone(), cancel_token.clone());
let sender = self.sender.clone();

let tasks = self.tasks.clone();
let tasks = self.running_tasks.clone();
let opts = self.opts.clone();
let chain_specs = self.chain_specs.clone();

Expand All @@ -109,7 +126,7 @@ impl ProofActor {
_ = cancel_token.cancelled() => {
info!("Task cancelled");
}
result = Self::handle_message(proof_request, key.clone(), &opts, &chain_specs) => {
result = Self::handle_message(proof_request.clone(), key.clone(), &opts, &chain_specs) => {
match result {
Ok(()) => {
info!("Host handling message");
Expand All @@ -122,25 +139,56 @@ impl ProofActor {
}
let mut tasks = tasks.lock().await;
tasks.remove(&key);
// notify that this task completed so the next pending task can run
sender
.send(Message::TaskComplete(proof_request))
.await
.expect("Couldn't send message");
});
}

pub async fn run(&mut self) {
let semaphore = Arc::new(Semaphore::new(self.opts.concurrency_limit));

// recv() is protected by outside mpsc, no lock needed here
while let Some(message) = self.receiver.recv().await {
match message {
Message::Cancel(key) => {
debug!("Message::Cancel task: {:?}", key);
if let Err(error) = self.cancel_task(key).await {
error!("Failed to cancel task: {error}")
}
}
Message::Task(proof_request) => {
let permit = Arc::clone(&semaphore)
.acquire_owned()
.await
.expect("Couldn't acquire permit");
self.run_task(proof_request, permit).await;
debug!("Message::Task proof_request: {:?}", proof_request);
let running_task_count = self.running_tasks.lock().await.len();
if running_task_count < self.opts.concurrency_limit {
info!("Running task {:?}", proof_request);
self.run_task(proof_request).await;
} else {
info!(
"Task concurrency limit reached, current running {:?}, pending: {:?}",
running_task_count,
self.pending_tasks.lock().await.len()
);
let mut pending_tasks = self.pending_tasks.lock().await;
pending_tasks.push_back(proof_request);
}
}
Message::TaskComplete(req) => {
// pop a pending task (if any) whenever a running task completes
debug!("Message::TaskComplete: {:?}", req);
info!(
"task completed, current running {:?}, pending: {:?}",
self.running_tasks.lock().await.len(),
self.pending_tasks.lock().await.len()
);
let mut pending_tasks = self.pending_tasks.lock().await;
if let Some(proof_request) = pending_tasks.pop_front() {
info!("Pop out pending task {:?}", proof_request);
self.sender
.send(Message::Task(proof_request))
.await
.expect("Couldn't send message");
}
}
}
}
Expand Down Expand Up @@ -189,7 +237,7 @@ pub async fn handle_proof(
store: Option<&mut TaskManagerWrapper>,
) -> HostResult<Proof> {
info!(
"# Generating proof for block {} on {}",
"Generating proof for block {} on {}",
proof_request.block_number, proof_request.network
);

Expand Down
Loading