Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(raiko): put the tasks that cannot run in parallel into pending list #358

Merged
merged 14 commits into from
Oct 14, 2024
5 changes: 3 additions & 2 deletions host/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ pub struct ProverState {
pub enum Message {
Cancel(TaskDescriptor),
Task(ProofRequest),
TaskComplete(ProofRequest),
smtmfft marked this conversation as resolved.
Show resolved Hide resolved
}

impl From<&ProofRequest> for Message {
Expand Down Expand Up @@ -184,9 +185,9 @@ impl ProverState {

let opts_clone = opts.clone();
let chain_specs_clone = chain_specs.clone();

let sender = task_channel.clone();
tokio::spawn(async move {
ProofActor::new(receiver, opts_clone, chain_specs_clone)
ProofActor::new(sender, receiver, opts_clone, chain_specs_clone)
.run()
.await;
});
Expand Down
88 changes: 68 additions & 20 deletions host/src/proof.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::{collections::HashMap, sync::Arc};
use std::{
collections::{HashMap, VecDeque},
sync::Arc,
};

use raiko_core::{
interfaces::{ProofRequest, RaikoError},
Expand All @@ -13,10 +16,13 @@ use raiko_lib::{
use raiko_tasks::{get_task_manager, TaskDescriptor, TaskManager, TaskManagerWrapper, TaskStatus};
use tokio::{
select,
sync::{mpsc::Receiver, Mutex, OwnedSemaphorePermit, Semaphore},
sync::{
mpsc::{Receiver, Sender},
Mutex,
},
};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use tracing::{debug, error, info, warn};

use crate::{
cache,
Expand All @@ -32,26 +38,36 @@ use crate::{
pub struct ProofActor {
opts: Opts,
chain_specs: SupportedChainSpecs,
tasks: Arc<Mutex<HashMap<TaskDescriptor, CancellationToken>>>,
running_tasks: Arc<Mutex<HashMap<TaskDescriptor, CancellationToken>>>,
pending_tasks: Arc<Mutex<VecDeque<ProofRequest>>>,
receiver: Receiver<Message>,
sender: Sender<Message>,
}

impl ProofActor {
pub fn new(receiver: Receiver<Message>, opts: Opts, chain_specs: SupportedChainSpecs) -> Self {
let tasks = Arc::new(Mutex::new(
pub fn new(
sender: Sender<Message>,
receiver: Receiver<Message>,
opts: Opts,
chain_specs: SupportedChainSpecs,
) -> Self {
let running_tasks = Arc::new(Mutex::new(
HashMap::<TaskDescriptor, CancellationToken>::new(),
));
let pending_tasks = Arc::new(Mutex::new(VecDeque::<ProofRequest>::new()));

Self {
tasks,
opts,
chain_specs,
running_tasks,
pending_tasks,
receiver,
sender,
smtmfft marked this conversation as resolved.
Show resolved Hide resolved
}
}

pub async fn cancel_task(&mut self, key: TaskDescriptor) -> HostResult<()> {
let tasks_map = self.tasks.lock().await;
let tasks_map = self.running_tasks.lock().await;
let Some(task) = tasks_map.get(&key) else {
warn!("No task with those keys to cancel");
return Ok(());
Expand All @@ -76,7 +92,7 @@ impl ProofActor {
Ok(())
}

pub async fn run_task(&mut self, proof_request: ProofRequest, _permit: OwnedSemaphorePermit) {
pub async fn run_task(&mut self, proof_request: ProofRequest) {
let cancel_token = CancellationToken::new();

let Ok((chain_id, blockhash)) = get_task_data(
Expand All @@ -97,10 +113,11 @@ impl ProofActor {
proof_request.prover.clone().to_string(),
));

let mut tasks = self.tasks.lock().await;
let mut tasks = self.running_tasks.lock().await;
tasks.insert(key.clone(), cancel_token.clone());
let sender = self.sender.clone();

let tasks = self.tasks.clone();
let tasks = self.running_tasks.clone();
let opts = self.opts.clone();
let chain_specs = self.chain_specs.clone();

Expand All @@ -109,7 +126,7 @@ impl ProofActor {
_ = cancel_token.cancelled() => {
info!("Task cancelled");
}
result = Self::handle_message(proof_request, key.clone(), &opts, &chain_specs) => {
result = Self::handle_message(proof_request.clone(), key.clone(), &opts, &chain_specs) => {
match result {
Ok(()) => {
info!("Host handling message");
Expand All @@ -122,25 +139,56 @@ impl ProofActor {
}
let mut tasks = tasks.lock().await;
tasks.remove(&key);
// notify complete task to let next pending task run
sender
.send(Message::TaskComplete(proof_request))
.await
.expect("Couldn't send message");
});
}

pub async fn run(&mut self) {
let semaphore = Arc::new(Semaphore::new(self.opts.concurrency_limit));

// recv() is protected by outside mpsc, no lock needed here
while let Some(message) = self.receiver.recv().await {
match message {
Message::Cancel(key) => {
debug!("Message::Cancel task: {:?}", key);
if let Err(error) = self.cancel_task(key).await {
smtmfft marked this conversation as resolved.
Show resolved Hide resolved
error!("Failed to cancel task: {error}")
}
}
Message::Task(proof_request) => {
let permit = Arc::clone(&semaphore)
.acquire_owned()
.await
.expect("Couldn't acquire permit");
self.run_task(proof_request, permit).await;
debug!("Message::Task proof_request: {:?}", proof_request);
let running_task_count = self.running_tasks.lock().await.len();
smtmfft marked this conversation as resolved.
Show resolved Hide resolved
if running_task_count < self.opts.concurrency_limit {
info!("Running task {:?}", proof_request);
self.run_task(proof_request).await;
smtmfft marked this conversation as resolved.
Show resolved Hide resolved
} else {
info!(
"Task concurrency limit reached, current running {:?}, pending: {:?}",
running_task_count,
self.pending_tasks.lock().await.len()
smtmfft marked this conversation as resolved.
Show resolved Hide resolved
);
let mut pending_tasks = self.pending_tasks.lock().await;
pending_tasks.push_back(proof_request);
}
}
Message::TaskComplete(req) => {
// pop up pending task if any task complete
debug!("Message::TaskComplete: {:?}", req);
info!(
smtmfft marked this conversation as resolved.
Show resolved Hide resolved
"task completed, current running {:?}, pending: {:?}",
self.running_tasks.lock().await.len(),
self.pending_tasks.lock().await.len()
);
let mut pending_tasks = self.pending_tasks.lock().await;
if let Some(proof_request) = pending_tasks.pop_front() {
info!("Pop out pending task {:?}", proof_request);
self.sender
smtmfft marked this conversation as resolved.
Show resolved Hide resolved
.send(Message::Task(proof_request))
.await
.expect("Couldn't send message");
}
}
}
}
Expand Down Expand Up @@ -189,7 +237,7 @@ pub async fn handle_proof(
store: Option<&mut TaskManagerWrapper>,
) -> HostResult<Proof> {
info!(
"# Generating proof for block {} on {}",
"Generating proof for block {} on {}",
proof_request.block_number, proof_request.network
);

Expand Down
Loading