diff --git a/host/src/lib.rs b/host/src/lib.rs
index 6927314b2..d0b2fb691 100644
--- a/host/src/lib.rs
+++ b/host/src/lib.rs
@@ -152,6 +152,7 @@ pub struct ProverState {
 pub enum Message {
     Cancel(TaskDescriptor),
     Task(ProofRequest),
+    TaskComplete(ProofRequest),
     CancelAggregate(AggregationOnlyRequest),
     Aggregate(AggregationOnlyRequest),
 }
@@ -200,9 +201,9 @@ impl ProverState {
 
         let opts_clone = opts.clone();
         let chain_specs_clone = chain_specs.clone();
-
+        let sender = task_channel.clone();
         tokio::spawn(async move {
-            ProofActor::new(receiver, opts_clone, chain_specs_clone)
+            ProofActor::new(sender, receiver, opts_clone, chain_specs_clone)
                 .run()
                 .await;
         });
diff --git a/host/src/proof.rs b/host/src/proof.rs
index 215a5b4f7..9223a5866 100644
--- a/host/src/proof.rs
+++ b/host/src/proof.rs
@@ -1,4 +1,8 @@
-use std::{collections::HashMap, str::FromStr, sync::Arc};
+use std::{
+    collections::{HashMap, VecDeque},
+    str::FromStr,
+    sync::Arc,
+};
 
 use anyhow::anyhow;
 use raiko_core::{
@@ -16,10 +20,13 @@ use raiko_tasks::{get_task_manager, TaskDescriptor, TaskManager, TaskManagerWrap
 use reth_primitives::B256;
 use tokio::{
     select,
-    sync::{mpsc::Receiver, Mutex, OwnedSemaphorePermit, Semaphore},
+    sync::{
+        mpsc::{Receiver, Sender},
+        Mutex, OwnedSemaphorePermit, Semaphore,
+    },
 };
 use tokio_util::sync::CancellationToken;
-use tracing::{error, info, warn};
+use tracing::{debug, error, info, warn};
 
 use crate::{
     cache,
@@ -35,32 +42,42 @@ use crate::{
 pub struct ProofActor {
     opts: Opts,
     chain_specs: SupportedChainSpecs,
-    tasks: Arc<Mutex<HashMap<TaskDescriptor, CancellationToken>>>,
     aggregate_tasks: Arc<Mutex<HashMap<AggregationOnlyRequest, CancellationToken>>>,
+    running_tasks: Arc<Mutex<HashMap<TaskDescriptor, CancellationToken>>>,
+    pending_tasks: Arc<Mutex<VecDeque<ProofRequest>>>,
     receiver: Receiver<Message>,
+    sender: Sender<Message>,
 }
 
 impl ProofActor {
-    pub fn new(receiver: Receiver<Message>, opts: Opts, chain_specs: SupportedChainSpecs) -> Self {
-        let tasks = Arc::new(Mutex::new(
+    pub fn new(
+        sender: Sender<Message>,
+        receiver: Receiver<Message>,
+        opts: Opts,
+        chain_specs: SupportedChainSpecs,
+    ) -> Self {
+        let running_tasks = Arc::new(Mutex::new(
             HashMap::<TaskDescriptor, CancellationToken>::new(),
         ));
         let aggregate_tasks = Arc::new(Mutex::new(HashMap::<
             AggregationOnlyRequest,
             CancellationToken,
         >::new()));
+        let pending_tasks = Arc::new(Mutex::new(VecDeque::<ProofRequest>::new()));
 
         Self {
-            tasks,
-            aggregate_tasks,
             opts,
             chain_specs,
+            aggregate_tasks,
+            running_tasks,
+            pending_tasks,
             receiver,
+            sender,
         }
     }
 
     pub async fn cancel_task(&mut self, key: TaskDescriptor) -> HostResult<()> {
-        let tasks_map = self.tasks.lock().await;
+        let tasks_map = self.running_tasks.lock().await;
         let Some(task) = tasks_map.get(&key) else {
             warn!("No task with those keys to cancel");
             return Ok(());
@@ -85,7 +102,7 @@ impl ProofActor {
         Ok(())
     }
 
-    pub async fn run_task(&mut self, proof_request: ProofRequest, _permit: OwnedSemaphorePermit) {
+    pub async fn run_task(&mut self, proof_request: ProofRequest) {
         let cancel_token = CancellationToken::new();
 
         let Ok((chain_id, blockhash)) = get_task_data(
@@ -106,10 +123,11 @@ impl ProofActor {
             proof_request.prover.clone().to_string(),
         ));
 
-        let mut tasks = self.tasks.lock().await;
+        let mut tasks = self.running_tasks.lock().await;
         tasks.insert(key.clone(), cancel_token.clone());
+        let sender = self.sender.clone();
 
-        let tasks = self.tasks.clone();
+        let tasks = self.running_tasks.clone();
         let opts = self.opts.clone();
         let chain_specs = self.chain_specs.clone();
 
@@ -118,7 +136,7 @@ impl ProofActor {
             _ = cancel_token.cancelled() => {
                 info!("Task cancelled");
             }
-            result = Self::handle_message(proof_request, key.clone(), &opts, &chain_specs) => {
+            result = Self::handle_message(proof_request.clone(), key.clone(), &opts, &chain_specs) => {
                 match result {
                     Ok(status) => {
                         info!("Host handling message: {status:?}");
@@ -131,6 +149,11 @@ impl ProofActor {
             }
             let mut tasks = tasks.lock().await;
             tasks.remove(&key);
+            // notify complete task to let next pending task run
+            sender
+                .send(Message::TaskComplete(proof_request))
+                .await
+                .expect("Couldn't send message");
         });
     }
 
@@ -203,21 +226,47 @@ impl ProofActor {
     }
 
     pub async fn run(&mut self) {
+        // recv() is protected by outside mpsc, no lock needed here
        let semaphore = Arc::new(Semaphore::new(self.opts.concurrency_limit));
-
         while let Some(message) = self.receiver.recv().await {
             match message {
                 Message::Cancel(key) => {
+                    debug!("Message::Cancel task: {key:?}");
                     if let Err(error) = self.cancel_task(key).await {
                         error!("Failed to cancel task: {error}")
                     }
                 }
                 Message::Task(proof_request) => {
-                    let permit = Arc::clone(&semaphore)
-                        .acquire_owned()
-                        .await
-                        .expect("Couldn't acquire permit");
-                    self.run_task(proof_request, permit).await;
+                    debug!("Message::Task proof_request: {proof_request:?}");
+                    let running_task_count = self.running_tasks.lock().await.len();
+                    if running_task_count < self.opts.concurrency_limit {
+                        info!("Running task {proof_request:?}");
+                        self.run_task(proof_request).await;
+                    } else {
+                        info!(
+                            "Task concurrency limit reached, current running {running_task_count:?}, pending: {:?}",
+                            self.pending_tasks.lock().await.len()
+                        );
+                        let mut pending_tasks = self.pending_tasks.lock().await;
+                        pending_tasks.push_back(proof_request);
+                    }
+                }
+                Message::TaskComplete(req) => {
+                    // pop up pending task if any task complete
+                    debug!("Message::TaskComplete: {req:?}");
+                    info!(
+                        "task completed, current running {:?}, pending: {:?}",
+                        self.running_tasks.lock().await.len(),
+                        self.pending_tasks.lock().await.len()
+                    );
+                    let mut pending_tasks = self.pending_tasks.lock().await;
+                    if let Some(proof_request) = pending_tasks.pop_front() {
+                        info!("Pop out pending task {proof_request:?}");
+                        self.sender
+                            .send(Message::Task(proof_request))
+                            .await
+                            .expect("Couldn't send message");
+                    }
                 }
                 Message::CancelAggregate(request) => {
                     if let Err(error) = self.cancel_aggregation_task(request).await {
@@ -326,7 +375,7 @@ pub async fn handle_proof(
     store: Option<&mut TaskManagerWrapper>,
 ) -> HostResult<Proof> {
     info!(
-        "# Generating proof for block {} on {}",
+        "Generating proof for block {} on {}",
         proof_request.block_number, proof_request.network
     );
 