diff --git a/Cargo.lock b/Cargo.lock index d47ec8685..590a56b4f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8285,6 +8285,7 @@ dependencies = [ "alloy-transport-http", "anyhow", "bincode", + "hex", "once_cell", "pem 3.0.4", "raiko-lib", diff --git a/core/src/interfaces.rs b/core/src/interfaces.rs index 3099565ed..63ad41140 100644 --- a/core/src/interfaces.rs +++ b/core/src/interfaces.rs @@ -3,13 +3,15 @@ use alloy_primitives::{Address, B256}; use clap::{Args, ValueEnum}; use raiko_lib::{ consts::VerifierType, - input::{BlobProofType, GuestInput, GuestOutput}, + input::{ + AggregationGuestInput, AggregationGuestOutput, BlobProofType, GuestInput, GuestOutput, + }, prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverError}, }; use serde::{Deserialize, Serialize}; use serde_json::Value; use serde_with::{serde_as, DisplayFromStr}; -use std::{collections::HashMap, path::Path, str::FromStr}; +use std::{collections::HashMap, fmt::Display, path::Path, str::FromStr}; use utoipa::ToSchema; #[derive(Debug, thiserror::Error, ToSchema)] @@ -203,6 +205,47 @@ impl ProofType { } } + /// Run the prover driver depending on the proof type. + pub async fn aggregate_proofs( + &self, + input: AggregationGuestInput, + output: &AggregationGuestOutput, + config: &Value, + store: Option<&mut dyn IdWrite>, + ) -> RaikoResult<Proof> { + let proof = match self { + ProofType::Native => NativeProver::aggregate(input.clone(), output, config, store) + .await + .map_err(<ProverError as Into<RaikoError>>::into), + ProofType::Sp1 => { + #[cfg(feature = "sp1")] + return sp1_driver::Sp1Prover::aggregate(input.clone(), output, config, store) + .await + .map_err(|e| e.into()); + #[cfg(not(feature = "sp1"))] + Err(RaikoError::FeatureNotSupportedError(*self)) + } + ProofType::Risc0 => { + #[cfg(feature = "risc0")] + return risc0_driver::Risc0Prover::aggregate(input.clone(), output, config, store) + .await + .map_err(|e| e.into()); + #[cfg(not(feature = "risc0"))] + Err(RaikoError::FeatureNotSupportedError(*self)) + } + ProofType::Sgx => { + #[cfg(feature = "sgx")] + return sgx_prover::SgxProver::aggregate(input.clone(), output, config, store) + .await + .map_err(|e| e.into()); + #[cfg(not(feature = "sgx"))] + Err(RaikoError::FeatureNotSupportedError(*self)) + } + }?; + + Ok(proof) + } + pub async fn cancel_proof( + &self, + proof_key: ProofKey, @@ -302,7 +345,7 @@ pub struct ProofRequestOpt { pub prover_args: ProverSpecificOpts, } -#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema, Args)] +#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema, Args, PartialEq, Eq, Hash)] pub struct ProverSpecificOpts { /// Native prover specific options. pub native: Option<Value>, @@ -398,3 +441,123 @@ impl TryFrom<ProofRequestOpt> for ProofRequest { }) } } + +#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema)] +#[serde(default)] +/// A request for proof aggregation of multiple proofs. +pub struct AggregationRequest { + /// The block numbers and l1 inclusion block numbers for the blocks to aggregate proofs for. + pub block_numbers: Vec<(u64, Option<u64>)>, + /// The network to generate the proof for. + pub network: Option<String>, + /// The L1 network to generate the proof for. + pub l1_network: Option<String>, + // Graffiti. + pub graffiti: Option<String>, + /// The protocol instance data. + pub prover: Option<String>, + /// The proof type. + pub proof_type: Option<String>, + /// Blob proof type. + pub blob_proof_type: Option<String>, + #[serde(flatten)] + /// Any additional prover params in JSON format.
+ pub prover_args: ProverSpecificOpts, +} + +impl AggregationRequest { + /// Merge proof request options into aggregation request options. + pub fn merge(&mut self, opts: &ProofRequestOpt) -> RaikoResult<()> { + let this = serde_json::to_value(&self)?; + let mut opts = serde_json::to_value(opts)?; + merge(&mut opts, &this); + *self = serde_json::from_value(opts)?; + Ok(()) + } +} + +impl From<AggregationRequest> for Vec<ProofRequestOpt> { + fn from(value: AggregationRequest) -> Self { + value + .block_numbers + .iter() + .map( + |&(block_number, l1_inclusion_block_number)| ProofRequestOpt { + block_number: Some(block_number), + l1_inclusion_block_number, + network: value.network.clone(), + l1_network: value.l1_network.clone(), + graffiti: value.graffiti.clone(), + prover: value.prover.clone(), + proof_type: value.proof_type.clone(), + blob_proof_type: value.blob_proof_type.clone(), + prover_args: value.prover_args.clone(), + }, + ) + .collect() + } +} + +impl From<ProofRequestOpt> for AggregationRequest { + fn from(value: ProofRequestOpt) -> Self { + let block_numbers = if let Some(block_number) = value.block_number { + vec![(block_number, value.l1_inclusion_block_number)] + } else { + vec![] + }; + + Self { + block_numbers, + network: value.network, + l1_network: value.l1_network, + graffiti: value.graffiti, + prover: value.prover, + proof_type: value.proof_type, + blob_proof_type: value.blob_proof_type, + prover_args: value.prover_args, + } + } +} + +#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema, PartialEq, Eq, Hash)] +#[serde(default)] +/// A request for proof aggregation of multiple proofs. +pub struct AggregationOnlyRequest { + /// The proofs to aggregate. + pub proofs: Vec<Proof>, + /// The proof type. + pub proof_type: Option<String>, + #[serde(flatten)] + /// Any additional prover params in JSON format. + pub prover_args: ProverSpecificOpts, +} + +impl Display for AggregationOnlyRequest { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&format!( + "AggregationOnlyRequest {{ {:?}, {:?} }}", + self.proof_type, self.prover_args + )) + } +} + +impl From<(AggregationRequest, Vec<Proof>)> for AggregationOnlyRequest { + fn from((request, proofs): (AggregationRequest, Vec<Proof>)) -> Self { + Self { + proofs, + proof_type: request.proof_type, + prover_args: request.prover_args, + } + } +} + +impl AggregationOnlyRequest { + /// Merge proof request options into aggregation request options. + pub fn merge(&mut self, opts: &ProofRequestOpt) -> RaikoResult<()> { + let this = serde_json::to_value(&self)?; + let mut opts = serde_json::to_value(opts)?; + merge(&mut opts, &this); + *self = serde_json::from_value(opts)?; + Ok(()) + } +}
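For orientation, the end-to-end flow this new API enables looks roughly like the following (a minimal sketch, not part of the diff; `single_proofs` and `prover_args` stand in for per-block proofs and options produced elsewhere):

    // Sketch: aggregate previously generated per-block proofs.
    let input = AggregationGuestInput { proofs: single_proofs };
    let output = AggregationGuestOutput { hash: B256::ZERO };
    let config = serde_json::to_value(&prover_args)?;
    let aggregated: Proof = proof_type
        .aggregate_proofs(input, &output, &config, None)
        .await?;

This mirrors the `test_prove_block_taiko_a7_aggregated` test added below.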
diff --git a/core/src/lib.rs b/core/src/lib.rs index cd026952b..48064e326 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -226,8 +226,9 @@ mod tests { use clap::ValueEnum; use raiko_lib::{ consts::{Network, SupportedChainSpecs}, - input::BlobProofType, + input::{AggregationGuestInput, AggregationGuestOutput, BlobProofType}, primitives::B256, + prover::Proof, }; use serde_json::{json, Value}; use std::{collections::HashMap, env}; @@ -242,7 +243,7 @@ ci == "1" } - fn test_proof_params() -> HashMap<String, Value> { + fn test_proof_params(enable_aggregation: bool) -> HashMap<String, Value> { let mut prover_args = HashMap::new(); prover_args.insert( "native".to_string(), json! @@ -256,7 +257,7 @@ "sp1".to_string(), json! { { - "recursion": "core", + "recursion": if enable_aggregation { "compressed" } else { "plonk" }, "prover": "mock", "verify": true } @@ -278,8 +279,8 @@ json! { { "instance_id": 121, - "setup": true, - "bootstrap": true, + "setup": enable_aggregation, + "bootstrap": enable_aggregation, "prove": true, } }, @@ -291,7 +292,7 @@ l1_chain_spec: ChainSpec, taiko_chain_spec: ChainSpec, proof_request: ProofRequest, - ) { + ) -> Proof { let provider = RpcBlockDataProvider::new(&taiko_chain_spec.rpc, proof_request.block_number - 1) .expect("Could not create RpcBlockDataProvider"); @@ -301,10 +302,10 @@ .await .expect("input generation failed"); let output = raiko.get_output(&input).expect("output generation failed"); - let _proof = raiko + raiko .prove(input, &output, None) .await - .expect("proof generation failed"); + .expect("proof generation failed") } #[ignore] @@ -332,7 +333,7 @@ l1_network, proof_type, blob_proof_type: BlobProofType::ProofOfEquivalence, - prover_args: test_proof_params(), + prover_args: test_proof_params(false), }; prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await; } @@ -361,7 +362,7 @@ l1_network, proof_type, blob_proof_type: BlobProofType::ProofOfEquivalence, - prover_args: test_proof_params(), + prover_args: test_proof_params(false), }; prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await; } @@ -399,7 +400,7 @@ l1_network, proof_type, blob_proof_type: BlobProofType::ProofOfEquivalence, - prover_args: test_proof_params(), + prover_args: test_proof_params(false), }; prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await; } @@ -432,9 +433,55 @@ l1_network, proof_type, blob_proof_type: BlobProofType::ProofOfEquivalence, - prover_args: test_proof_params(), + prover_args: test_proof_params(false), }; prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await; } } + + #[tokio::test(flavor = "multi_thread")] + async fn test_prove_block_taiko_a7_aggregated() { + let proof_type = get_proof_type_from_env(); + let l1_network = Network::Holesky.to_string(); + let network = Network::TaikoA7.to_string(); + // Give the CI a simpler block to test because it doesn't have enough memory. + // Unfortunately that also means that kzg is not getting fully verified by CI.
+ let block_number = if is_ci() { 105987 } else { 101368 }; + let taiko_chain_spec = SupportedChainSpecs::default() + .get_chain_spec(&network) + .unwrap(); + let l1_chain_spec = SupportedChainSpecs::default() + .get_chain_spec(&l1_network) + .unwrap(); + + let proof_request = ProofRequest { + block_number, + l1_inclusion_block_number: 0, + network, + graffiti: B256::ZERO, + prover: Address::ZERO, + l1_network, + proof_type, + blob_proof_type: BlobProofType::ProofOfEquivalence, + prover_args: test_proof_params(true), + }; + let proof = prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await; + + let input = AggregationGuestInput { + proofs: vec![proof.clone(), proof], + }; + + let output = AggregationGuestOutput { hash: B256::ZERO }; + + let aggregated_proof = proof_type + .aggregate_proofs( + input, + &output, + &serde_json::to_value(&test_proof_params(false)).unwrap(), + None, + ) + .await + .expect("proof aggregation failed"); + println!("aggregated proof: {aggregated_proof:?}"); + } } diff --git a/core/src/preflight/util.rs b/core/src/preflight/util.rs index 889134d94..10fb6394c 100644 --- a/core/src/preflight/util.rs +++ b/core/src/preflight/util.rs @@ -136,11 +136,8 @@ pub async fn prepare_taiko_chain_input( RaikoError::Preflight("No L1 inclusion block hash for the requested block".to_owned()) })?; info!( - "L1 inclusion block number: {:?}, hash: {:?}. L1 state block number: {:?}, hash: {:?}", - l1_inclusion_block_number, - l1_inclusion_block_hash, + "L1 inclusion block number: {l1_inclusion_block_number:?}, hash: {l1_inclusion_block_hash:?}. L1 state block number: {:?}, hash: {l1_state_block_hash:?}", l1_state_header.number, - l1_state_block_hash ); // Fetch the tx data from either calldata or blobdata diff --git a/core/src/prover.rs b/core/src/prover.rs index 577c5318a..de89d859e 100644 --- a/core/src/prover.rs +++ b/core/src/prover.rs @@ -58,14 +58,28 @@ impl Prover for NativeProver { } Ok(Proof { + input: None, proof: None, quote: None, + uuid: None, + kzg_proof: None, }) } async fn cancel(_proof_key: ProofKey, _read: Box<&mut dyn IdStore>) -> ProverResult<()> { Ok(()) } + + async fn aggregate( + _input: raiko_lib::input::AggregationGuestInput, + _output: &raiko_lib::input::AggregationGuestOutput, + _config: &ProverConfig, + _store: Option<&mut dyn IdWrite>, + ) -> ProverResult { + Ok(Proof { + ..Default::default() + }) + } } #[ignore = "Only used to test serialized data"] diff --git a/host/src/cache.rs b/host/src/cache.rs index c4cd99815..52fe34a53 100644 --- a/host/src/cache.rs +++ b/host/src/cache.rs @@ -55,10 +55,7 @@ pub async fn validate_input( let cached_block_hash = cache_input.block.header.hash_slow(); let real_block_hash = block.header.hash.unwrap(); - debug!( - "cache_block_hash={:?}, real_block_hash={:?}", - cached_block_hash, real_block_hash - ); + debug!("cache_block_hash={cached_block_hash:?}, real_block_hash={real_block_hash:?}"); // double check if cache is valid if cached_block_hash == real_block_hash { diff --git a/host/src/interfaces.rs b/host/src/interfaces.rs index 728d7710a..330446ef4 100644 --- a/host/src/interfaces.rs +++ b/host/src/interfaces.rs @@ -121,12 +121,12 @@ impl From for TaskStatus { | HostError::JoinHandle(_) | HostError::InvalidAddress(_) | HostError::InvalidRequestConfig(_) => unreachable!(), - HostError::Conversion(_) - | HostError::Serde(_) - | HostError::Core(_) - | HostError::Anyhow(_) - | HostError::FeatureNotSupportedError(_) - | HostError::Io(_) => TaskStatus::UnspecifiedFailureReason, + HostError::Conversion(e) => 
TaskStatus::NonDbFailure(e), + HostError::Serde(e) => TaskStatus::NonDbFailure(e.to_string()), + HostError::Core(e) => TaskStatus::NonDbFailure(e.to_string()), + HostError::Anyhow(e) => TaskStatus::NonDbFailure(e.to_string()), + HostError::FeatureNotSupportedError(e) => TaskStatus::NonDbFailure(e.to_string()), + HostError::Io(e) => TaskStatus::NonDbFailure(e.to_string()), HostError::RPC(_) => TaskStatus::NetworkFailure, HostError::Guest(_) => TaskStatus::ProofFailure_Generic, HostError::TaskManager(_) => TaskStatus::SqlDbCorruption, @@ -142,12 +142,12 @@ impl From<&HostError> for TaskStatus { | HostError::JoinHandle(_) | HostError::InvalidAddress(_) | HostError::InvalidRequestConfig(_) => unreachable!(), - HostError::Conversion(_) - | HostError::Serde(_) - | HostError::Core(_) - | HostError::Anyhow(_) - | HostError::FeatureNotSupportedError(_) - | HostError::Io(_) => TaskStatus::UnspecifiedFailureReason, + HostError::Conversion(e) => TaskStatus::NonDbFailure(e.to_owned()), + HostError::Serde(e) => TaskStatus::NonDbFailure(e.to_string()), + HostError::Core(e) => TaskStatus::NonDbFailure(e.to_string()), + HostError::Anyhow(e) => TaskStatus::NonDbFailure(e.to_string()), + HostError::FeatureNotSupportedError(e) => TaskStatus::NonDbFailure(e.to_string()), + HostError::Io(e) => TaskStatus::NonDbFailure(e.to_string()), HostError::RPC(_) => TaskStatus::NetworkFailure, HostError::Guest(_) => TaskStatus::ProofFailure_Generic, HostError::TaskManager(_) => TaskStatus::SqlDbCorruption, diff --git a/host/src/lib.rs b/host/src/lib.rs index a4df64dc9..6927314b2 100644 --- a/host/src/lib.rs +++ b/host/src/lib.rs @@ -4,7 +4,7 @@ use anyhow::Context; use cap::Cap; use clap::Parser; use raiko_core::{ - interfaces::{ProofRequest, ProofRequestOpt}, + interfaces::{AggregationOnlyRequest, ProofRequest, ProofRequestOpt}, merge, }; use raiko_lib::consts::SupportedChainSpecs; @@ -152,6 +152,8 @@ pub struct ProverState { pub enum Message { Cancel(TaskDescriptor), Task(ProofRequest), + CancelAggregate(AggregationOnlyRequest), + Aggregate(AggregationOnlyRequest), } impl From<&ProofRequest> for Message { @@ -166,6 +168,12 @@ impl From<&TaskDescriptor> for Message { } } +impl From for Message { + fn from(value: AggregationOnlyRequest) -> Self { + Self::Aggregate(value) + } +} + impl ProverState { pub fn init() -> HostResult { // Read the command line arguments; diff --git a/host/src/proof.rs b/host/src/proof.rs index 31a56e72a..215a5b4f7 100644 --- a/host/src/proof.rs +++ b/host/src/proof.rs @@ -1,16 +1,19 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, str::FromStr, sync::Arc}; +use anyhow::anyhow; use raiko_core::{ - interfaces::{ProofRequest, RaikoError}, + interfaces::{AggregationOnlyRequest, ProofRequest, ProofType, RaikoError}, provider::{get_task_data, rpc::RpcBlockDataProvider}, Raiko, }; use raiko_lib::{ consts::SupportedChainSpecs, + input::{AggregationGuestInput, AggregationGuestOutput}, prover::{IdWrite, Proof}, Measurement, }; use raiko_tasks::{get_task_manager, TaskDescriptor, TaskManager, TaskManagerWrapper, TaskStatus}; +use reth_primitives::B256; use tokio::{ select, sync::{mpsc::Receiver, Mutex, OwnedSemaphorePermit, Semaphore}, @@ -33,6 +36,7 @@ pub struct ProofActor { opts: Opts, chain_specs: SupportedChainSpecs, tasks: Arc>>, + aggregate_tasks: Arc>>, receiver: Receiver, } @@ -41,9 +45,14 @@ impl ProofActor { let tasks = Arc::new(Mutex::new( HashMap::::new(), )); + let aggregate_tasks = Arc::new(Mutex::new(HashMap::< + AggregationOnlyRequest, + CancellationToken, + 
>::new())); Self { tasks, + aggregate_tasks, opts, chain_specs, receiver, @@ -125,6 +134,74 @@ impl ProofActor { }); } + pub async fn cancel_aggregation_task( + &mut self, + request: AggregationOnlyRequest, + ) -> HostResult<()> { + let tasks_map = self.aggregate_tasks.lock().await; + let Some(task) = tasks_map.get(&request) else { + warn!("No task with those keys to cancel"); + return Ok(()); + }; + + // TODO:(petar) implement cancel_proof_aggregation + // let mut manager = get_task_manager(&self.opts.clone().into()); + // let proof_type = ProofType::from_str( + // request + // .proof_type + // .as_ref() + // .ok_or_else(|| anyhow!("No proof type"))?, + // )?; + // proof_type + // .cancel_proof_aggregation(request, Box::new(&mut manager)) + // .await + // .or_else(|e| { + // if e.to_string().contains("No data for query") { + // warn!("Task already cancelled or not yet started!"); + // Ok(()) + // } else { + // Err::<(), HostError>(e.into()) + // } + // })?; + task.cancel(); + Ok(()) + } + + pub async fn run_aggregate( + &mut self, + request: AggregationOnlyRequest, + _permit: OwnedSemaphorePermit, + ) { + let cancel_token = CancellationToken::new(); + + let mut tasks = self.aggregate_tasks.lock().await; + tasks.insert(request.clone(), cancel_token.clone()); + + let request_clone = request.clone(); + let tasks = self.aggregate_tasks.clone(); + let opts = self.opts.clone(); + + tokio::spawn(async move { + select! { + _ = cancel_token.cancelled() => { + info!("Task cancelled"); + } + result = Self::handle_aggregate(request_clone, &opts) => { + match result { + Ok(()) => { + info!("Host handling message"); + } + Err(error) => { + error!("Worker failed due to: {error:?}"); + } + }; + } + } + let mut tasks = tasks.lock().await; + tasks.remove(&request); + }); + } + pub async fn run(&mut self) { let semaphore = Arc::new(Semaphore::new(self.opts.concurrency_limit)); @@ -142,6 +219,18 @@ impl ProofActor { .expect("Couldn't acquire permit"); self.run_task(proof_request, permit).await; } + Message::CancelAggregate(request) => { + if let Err(error) = self.cancel_aggregation_task(request).await { + error!("Failed to cancel task: {error}") + } + } + Message::Aggregate(request) => { + let permit = Arc::clone(&semaphore) + .acquire_owned() + .await + .expect("Couldn't acquire permit"); + self.run_aggregate(request, permit).await; + } } } } @@ -158,7 +247,7 @@ impl ProofActor { if let Some(latest_status) = status.iter().last() { if !matches!(latest_status.0, TaskStatus::Registered) { - return Ok(latest_status.0); + return Ok(latest_status.0.clone()); } } @@ -176,11 +265,58 @@ impl ProofActor { }; manager - .update_task_progress(key, status, proof.as_deref()) + .update_task_progress(key, status.clone(), proof.as_deref()) .await .map_err(HostError::from)?; Ok(status) } + + pub async fn handle_aggregate(request: AggregationOnlyRequest, opts: &Opts) -> HostResult<()> { + let mut manager = get_task_manager(&opts.clone().into()); + + let status = manager + .get_aggregation_task_proving_status(&request) + .await?; + + if let Some(latest_status) = status.iter().last() { + if !matches!(latest_status.0, TaskStatus::Registered) { + return Ok(()); + } + } + + manager + .update_aggregation_task_progress(&request, TaskStatus::WorkInProgress, None) + .await?; + let proof_type = ProofType::from_str( + request + .proof_type + .as_ref() + .ok_or_else(|| anyhow!("No proof type"))?, + )?; + let input = AggregationGuestInput { + proofs: request.clone().proofs, + }; + let output = AggregationGuestOutput { hash: B256::ZERO }; + 
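// Note: the zero output hash mirrors the aggregation test in core/src/lib.rs;
// the aggregation drivers shown in this diff take the expected output as
// `_output` and do not consult it, because the guest recomputes and commits
// the aggregated output itself.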
let config = serde_json::to_value(request.clone().prover_args)?; + let mut manager = get_task_manager(&opts.clone().into()); + + let (status, proof) = match proof_type + .aggregate_proofs(input, &output, &config, Some(&mut manager)) + .await + { + Err(error) => { + error!("{error}"); + (HostError::from(error).into(), None) + } + Ok(proof) => (TaskStatus::Success, Some(serde_json::to_vec(&proof)?)), + }; + + manager + .update_aggregation_task_progress(&request, status, proof.as_deref()) + .await?; + + Ok(()) + } } pub async fn handle_proof( diff --git a/host/src/server/api/mod.rs b/host/src/server/api/mod.rs index 4aa8e0981..45be92f15 100644 --- a/host/src/server/api/mod.rs +++ b/host/src/server/api/mod.rs @@ -18,6 +18,7 @@ use crate::ProverState; pub mod v1; pub mod v2; +pub mod v3; pub fn create_router(concurrency_limit: usize, jwt_secret: Option<&str>) -> Router { let cors = CorsLayer::new() @@ -37,11 +38,13 @@ pub fn create_router(concurrency_limit: usize, jwt_secret: Option<&str>) -> Rout let v1_api = v1::create_router(concurrency_limit); let v2_api = v2::create_router(); + let v3_api = v3::create_router(); let router = Router::new() .nest("/v1", v1_api) - .nest("/v2", v2_api.clone()) - .merge(v2_api) + .nest("/v2", v2_api) + .nest("/v3", v3_api.clone()) + .merge(v3_api) .layer(middleware) .layer(middleware::from_fn(check_max_body_size)) .layer(trace) @@ -58,7 +61,7 @@ pub fn create_router(concurrency_limit: usize, jwt_secret: Option<&str>) -> Rout } pub fn create_docs() -> utoipa::openapi::OpenApi { - v2::create_docs() + v3::create_docs() } async fn check_max_body_size(req: Request, next: Next) -> Response { diff --git a/host/src/server/api/v2/mod.rs b/host/src/server/api/v2/mod.rs index 6985369bc..f4fc046a7 100644 --- a/host/src/server/api/v2/mod.rs +++ b/host/src/server/api/v2/mod.rs @@ -11,7 +11,7 @@ use crate::{ ProverState, }; -mod proof; +pub mod proof; #[derive(OpenApi)] #[openapi( @@ -157,6 +157,8 @@ pub fn create_router() -> Router { // Only add the concurrency limit to the proof route. We want to still be able to call // healthchecks and metrics to have insight into the system. .nest("/proof", proof::create_router()) + // TODO: Separate task or try to get it into /proof somehow? Probably separate + .nest("/aggregate", proof::create_router()) .nest("/health", v1::health::create_router()) .nest("/metrics", v1::metrics::create_router()) .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", docs.clone())) diff --git a/host/src/server/api/v2/proof/mod.rs b/host/src/server/api/v2/proof/mod.rs index ce089375c..d57335cdf 100644 --- a/host/src/server/api/v2/proof/mod.rs +++ b/host/src/server/api/v2/proof/mod.rs @@ -11,10 +11,10 @@ use crate::{ Message, ProverState, }; -mod cancel; -mod list; -mod prune; -mod report; +pub mod cancel; +pub mod list; +pub mod prune; +pub mod report; #[utoipa::path(post, path = "/proof", tag = "Proving", @@ -98,7 +98,7 @@ async fn proof_handler( Ok(proof.into()) } // For all other statuses just return the status. 
- status => Ok((*status).into()), + status => Ok(status.clone().into()), } } diff --git a/host/src/server/api/v3/mod.rs b/host/src/server/api/v3/mod.rs new file mode 100644 index 000000000..faf46b61e --- /dev/null +++ b/host/src/server/api/v3/mod.rs @@ -0,0 +1,172 @@ +use axum::{response::IntoResponse, Json, Router}; +use raiko_lib::prover::Proof; +use raiko_tasks::TaskStatus; +use serde::{Deserialize, Serialize}; +use utoipa::{OpenApi, ToSchema}; +use utoipa_scalar::{Scalar, Servable}; +use utoipa_swagger_ui::SwaggerUi; + +use crate::{ + server::api::v1::{self, GuestOutputDoc}, + ProverState, +}; + +mod proof; + +#[derive(OpenApi)] +#[openapi( + info( + title = "Raiko Proverd Server API", + version = "3.0", + description = "Raiko Proverd Server API", + contact( + name = "API Support", + url = "https://community.taiko.xyz", + email = "info@taiko.xyz", + ), + license( + name = "MIT", + url = "https://github.com/taikoxyz/raiko/blob/main/LICENSE" + ), + ), + components( + schemas( + raiko_core::interfaces::ProofRequestOpt, + raiko_core::interfaces::ProverSpecificOpts, + crate::interfaces::HostError, + GuestOutputDoc, + ProofResponse, + TaskStatus, + CancelStatus, + PruneStatus, + Proof, + Status, + ) + ), + tags( + (name = "Proving", description = "Routes that handle proving requests"), + (name = "Health", description = "Routes that report the server health status"), + (name = "Metrics", description = "Routes that give detailed insight into the server") + ) +)] +/// The root API struct which is generated from the `OpenApi` derive macro. +pub struct Docs; + +#[derive(Debug, Deserialize, Serialize, ToSchema)] +#[serde(untagged)] +pub enum ProofResponse { + Status { + /// The status of the submitted task. + status: TaskStatus, + }, + Proof { + /// The proof. + proof: Proof, + }, +} + +#[derive(Debug, Deserialize, Serialize, ToSchema)] +#[serde(tag = "status", rename_all = "lowercase")] +pub enum Status { + Ok { data: ProofResponse }, + Error { error: String, message: String }, +} + +impl From> for Status { + fn from(proof: Vec) -> Self { + Self::Ok { + data: ProofResponse::Proof { + proof: serde_json::from_slice(&proof).unwrap_or_default(), + }, + } + } +} + +impl From for Status { + fn from(proof: Proof) -> Self { + Self::Ok { + data: ProofResponse::Proof { proof }, + } + } +} + +impl From for Status { + fn from(status: TaskStatus) -> Self { + match status { + TaskStatus::Success | TaskStatus::WorkInProgress | TaskStatus::Registered => Self::Ok { + data: ProofResponse::Status { status }, + }, + _ => Self::Error { + error: "task_failed".to_string(), + message: format!("Task failed with status: {status:?}"), + }, + } + } +} + +impl IntoResponse for Status { + fn into_response(self) -> axum::response::Response { + Json(serde_json::to_value(self).unwrap()).into_response() + } +} + +#[derive(Debug, Deserialize, Serialize, ToSchema)] +#[serde(tag = "status", rename_all = "lowercase")] +/// Status of cancellation request. +/// Can be `ok` for a successful cancellation or `error` with message and error type for errors. +pub enum CancelStatus { + /// Cancellation was successful. + Ok, + /// Cancellation failed. + Error { error: String, message: String }, +} + +impl IntoResponse for CancelStatus { + fn into_response(self) -> axum::response::Response { + Json(serde_json::to_value(self).unwrap()).into_response() + } +} + +#[derive(Debug, Serialize, ToSchema, Deserialize)] +#[serde(tag = "status", rename_all = "lowercase")] +/// Status of prune request. 
+/// Can be `ok` for a successful prune or `error` with message and error type for errors. +pub enum PruneStatus { + /// Prune was successful. + Ok, + /// Prune failed. + Error { error: String, message: String }, +} + +impl IntoResponse for PruneStatus { + fn into_response(self) -> axum::response::Response { + Json(serde_json::to_value(self).unwrap()).into_response() + } +} + +#[must_use] +pub fn create_docs() -> utoipa::openapi::OpenApi { + [ + v1::health::create_docs(), + v1::metrics::create_docs(), + proof::create_docs(), + ] + .into_iter() + .fold(Docs::openapi(), |mut doc, sub_doc| { + doc.merge(sub_doc); + doc + }) +} + +pub fn create_router() -> Router { + let docs = create_docs(); + + Router::new() + // Only add the concurrency limit to the proof route. We want to still be able to call + // healthchecks and metrics to have insight into the system. + .nest("/proof", proof::create_router()) + .nest("/health", v1::health::create_router()) + .nest("/metrics", v1::metrics::create_router()) + .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", docs.clone())) + .merge(Scalar::with_url("/scalar", docs)) +} diff --git a/host/src/server/api/v3/proof/aggregate.rs b/host/src/server/api/v3/proof/aggregate.rs new file mode 100644 index 000000000..3bbffa00f --- /dev/null +++ b/host/src/server/api/v3/proof/aggregate.rs @@ -0,0 +1,114 @@ +use std::str::FromStr; + +use axum::{debug_handler, extract::State, routing::post, Json, Router}; +use raiko_core::interfaces::{AggregationOnlyRequest, ProofType}; +use raiko_tasks::{TaskManager, TaskStatus}; +use utoipa::OpenApi; + +use crate::{ + interfaces::HostResult, + metrics::{inc_current_req, inc_guest_req_count, inc_host_req_count}, + server::api::v3::Status, + Message, ProverState, +}; + +#[utoipa::path(post, path = "/proof/aggregate", + tag = "Proving", + request_body = AggregationRequest, + responses ( + (status = 200, description = "Successfully submitted proof aggregation task, queried aggregation tasks in progress or retrieved aggregated proof.", body = Status) + ) +)] +#[debug_handler(state = ProverState)] +/// Submit a proof aggregation task with requested config, get task status or get proof value. +/// +/// Accepts a proof request and creates a proving task with the specified guest prover. +/// The guest provers currently available are: +/// - native - constructs a block and checks for equality +/// - sgx - uses the sgx environment to construct a block and produce proof of execution +/// - sp1 - uses the sp1 prover +/// - risc0 - uses the risc0 prover +async fn aggregation_handler( + State(prover_state): State, + Json(mut aggregation_request): Json, +) -> HostResult { + inc_current_req(); + // Override the existing proof request config from the config file and command line + // options with the request from the client. + aggregation_request.merge(&prover_state.request_config())?; + + let proof_type = ProofType::from_str( + aggregation_request + .proof_type + .as_deref() + .unwrap_or_default(), + )?; + inc_host_req_count(0); + inc_guest_req_count(&proof_type, 0); + + if aggregation_request.proofs.is_empty() { + return Err(anyhow::anyhow!("No proofs provided").into()); + } + + let mut manager = prover_state.task_manager(); + + let status = manager + .get_aggregation_task_proving_status(&aggregation_request) + .await?; + + let Some((latest_status, ..)) = status.last() else { + // If there are no tasks with provided config, create a new one. 
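// Illustrative request body for this endpoint (hypothetical values; field
// names follow `AggregationOnlyRequest`, whose `prover_args` is flattened,
// so backend-specific options sit at the top level):
//
//     {
//         "proofs": [
//             { "proof": "0x..", "input": "0x..", "quote": null,
//               "uuid": null, "kzg_proof": null }
//         ],
//         "proof_type": "risc0",
//         "risc0": { "bonsai": true, "snark": true }
//     }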
+ manager + .enqueue_aggregation_task(&aggregation_request) + .await?; + + prover_state + .task_channel + .try_send(Message::from(aggregation_request.clone()))?; + return Ok(Status::from(TaskStatus::Registered)); + }; + + match latest_status { + // If task has been cancelled add it to the queue again + TaskStatus::Cancelled + | TaskStatus::Cancelled_Aborted + | TaskStatus::Cancelled_NeverStarted + | TaskStatus::CancellationInProgress => { + manager + .update_aggregation_task_progress( + &aggregation_request, + TaskStatus::Registered, + None, + ) + .await?; + + prover_state + .task_channel + .try_send(Message::from(aggregation_request))?; + + Ok(Status::from(TaskStatus::Registered)) + } + // If the task has succeeded, return the proof. + TaskStatus::Success => { + let proof = manager + .get_aggregation_task_proof(&aggregation_request) + .await?; + + Ok(proof.into()) + } + // For all other statuses just return the status. + status => Ok(status.clone().into()), + } +} + +#[derive(OpenApi)] +#[openapi(paths(aggregation_handler))] +struct Docs; + +pub fn create_docs() -> utoipa::openapi::OpenApi { + Docs::openapi() +} + +pub fn create_router() -> Router { + Router::new().route("/", post(aggregation_handler)) +} diff --git a/host/src/server/api/v3/proof/cancel.rs b/host/src/server/api/v3/proof/cancel.rs new file mode 100644 index 000000000..6e721c716 --- /dev/null +++ b/host/src/server/api/v3/proof/cancel.rs @@ -0,0 +1,76 @@ +use axum::{debug_handler, extract::State, routing::post, Json, Router}; +use raiko_core::{ + interfaces::{AggregationRequest, ProofRequest, ProofRequestOpt}, + provider::get_task_data, +}; +use raiko_tasks::{TaskDescriptor, TaskManager, TaskStatus}; +use utoipa::OpenApi; + +use crate::{interfaces::HostResult, server::api::v2::CancelStatus, Message, ProverState}; + +#[utoipa::path(post, path = "/proof/cancel", + tag = "Proving", + request_body = ProofRequestOpt, + responses ( + (status = 200, description = "Successfully cancelled proof task", body = CancelStatus) + ) +)] +#[debug_handler(state = ProverState)] +/// Cancel a proof task with requested config. +/// +/// Accepts a proof request and cancels a proving task with the specified guest prover. +/// The guest provers currently available are: +/// - native - constructs a block and checks for equality +/// - sgx - uses the sgx environment to construct a block and produce proof of execution +/// - sp1 - uses the sp1 prover +/// - risc0 - uses the risc0 prover +async fn cancel_handler( + State(prover_state): State, + Json(mut aggregation_request): Json, +) -> HostResult { + // Override the existing proof request config from the config file and command line + // options with the request from the client. 
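// `merge` round-trips both sides through serde_json and overlays the client's
// values onto the server defaults, so any field the client sets wins (e.g. a
// request carrying "proof_type": "sp1" overrides a configured default of
// "native"); see `AggregationRequest::merge` in core/src/interfaces.rs.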
+ aggregation_request.merge(&prover_state.request_config())?; + + let proof_request_opts: Vec = aggregation_request.into(); + + for opt in proof_request_opts { + let proof_request = ProofRequest::try_from(opt)?; + + let (chain_id, block_hash) = get_task_data( + &proof_request.network, + proof_request.block_number, + &prover_state.chain_specs, + ) + .await?; + + let key = TaskDescriptor::from(( + chain_id, + block_hash, + proof_request.proof_type, + proof_request.prover.clone().to_string(), + )); + + prover_state.task_channel.try_send(Message::from(&key))?; + + let mut manager = prover_state.task_manager(); + + manager + .update_task_progress(key, TaskStatus::Cancelled, None) + .await?; + } + + Ok(CancelStatus::Ok) +} + +#[derive(OpenApi)] +#[openapi(paths(cancel_handler))] +struct Docs; + +pub fn create_docs() -> utoipa::openapi::OpenApi { + Docs::openapi() +} + +pub fn create_router() -> Router { + Router::new().route("/", post(cancel_handler)) +} diff --git a/host/src/server/api/v3/proof/mod.rs b/host/src/server/api/v3/proof/mod.rs new file mode 100644 index 000000000..2e739cc58 --- /dev/null +++ b/host/src/server/api/v3/proof/mod.rs @@ -0,0 +1,219 @@ +use axum::{debug_handler, extract::State, routing::post, Json, Router}; +use raiko_core::{ + interfaces::{AggregationOnlyRequest, AggregationRequest, ProofRequest, ProofRequestOpt}, + provider::get_task_data, +}; +use raiko_tasks::{TaskDescriptor, TaskManager, TaskStatus}; +use utoipa::OpenApi; + +use crate::{ + interfaces::HostResult, + metrics::{inc_current_req, inc_guest_req_count, inc_host_req_count}, + server::api::{v2, v3::Status}, + Message, ProverState, +}; +use tracing::{debug, info}; + +mod aggregate; +mod cancel; + +#[utoipa::path(post, path = "/proof", + tag = "Proving", + request_body = AggregationRequest, + responses ( + (status = 200, description = "Successfully submitted proof task, queried tasks in progress or retrieved proof.", body = Status) + ) +)] +#[debug_handler(state = ProverState)] +/// Submit a proof aggregation task with requested config, get task status or get proof value. +/// +/// Accepts a proof request and creates a proving task with the specified guest prover. +/// The guest provers currently available are: +/// - native - constructs a block and checks for equality +/// - sgx - uses the sgx environment to construct a block and produce proof of execution +/// - sp1 - uses the sp1 prover +/// - risc0 - uses the risc0 prover +async fn proof_handler( + State(prover_state): State, + Json(mut aggregation_request): Json, +) -> HostResult { + inc_current_req(); + // Override the existing proof request config from the config file and command line + // options with the request from the client. + aggregation_request.merge(&prover_state.request_config())?; + + let mut tasks = Vec::with_capacity(aggregation_request.block_numbers.len()); + + let proof_request_opts: Vec = aggregation_request.clone().into(); + + if proof_request_opts.is_empty() { + return Err(anyhow::anyhow!("No blocks for proving provided").into()); + } + + // Construct the actual proof request from the available configs. 
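// Fan-out sketch: block_numbers = [(100, None), (101, Some(7))] expands into
// two `ProofRequestOpt`s differing only in `block_number` and
// `l1_inclusion_block_number`, per the `From<AggregationRequest>` impl in
// core/src/interfaces.rs (the block numbers here are illustrative).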
+ for proof_request_opt in proof_request_opts { + let proof_request = ProofRequest::try_from(proof_request_opt)?; + + inc_host_req_count(proof_request.block_number); + inc_guest_req_count(&proof_request.proof_type, proof_request.block_number); + + let (chain_id, blockhash) = get_task_data( + &proof_request.network, + proof_request.block_number, + &prover_state.chain_specs, + ) + .await?; + + let key = TaskDescriptor::from(( + chain_id, + blockhash, + proof_request.proof_type, + proof_request.prover.to_string(), + )); + + tasks.push((key, proof_request)); + } + + let mut manager = prover_state.task_manager(); + + let mut is_registered = false; + let mut is_success = true; + let mut statuses = Vec::with_capacity(tasks.len()); + + for (key, req) in tasks.iter() { + let status = manager.get_task_proving_status(key).await?; + + let Some((latest_status, ..)) = status.last() else { + // If there are no tasks with provided config, create a new one. + manager.enqueue_task(key).await?; + + prover_state.task_channel.try_send(Message::from(req))?; + is_registered = true; + continue; + }; + + match latest_status { + // If task has been cancelled add it to the queue again + TaskStatus::Cancelled + | TaskStatus::Cancelled_Aborted + | TaskStatus::Cancelled_NeverStarted + | TaskStatus::CancellationInProgress => { + manager + .update_task_progress(key.clone(), TaskStatus::Registered, None) + .await?; + + prover_state.task_channel.try_send(Message::from(req))?; + + is_registered = true; + is_success = false; + } + // If the task has succeeded, return the proof. + TaskStatus::Success => {} + // For all other statuses just return the status. + status => { + statuses.push(status.clone()); + is_registered = false; + is_success = false; + } + } + } + + if is_registered { + Ok(TaskStatus::Registered.into()) + } else if is_success { + info!("All tasks are successful, aggregating proofs"); + let mut proofs = Vec::with_capacity(tasks.len()); + for (task, req) in tasks { + let raw_proof = manager.get_task_proof(&task).await?; + let proof = serde_json::from_slice(&raw_proof)?; + debug!("req: {req:?} gets proof: {proof:?}"); + proofs.push(proof); + } + + let aggregation_request = AggregationOnlyRequest { + proofs, + proof_type: aggregation_request.proof_type, + prover_args: aggregation_request.prover_args, + }; + + let status = manager + .get_aggregation_task_proving_status(&aggregation_request) + .await?; + + let Some((latest_status, ..)) = status.last() else { + // If there are no tasks with provided config, create a new one. + manager + .enqueue_aggregation_task(&aggregation_request) + .await?; + + prover_state + .task_channel + .try_send(Message::from(aggregation_request.clone()))?; + return Ok(Status::from(TaskStatus::Registered)); + }; + + match latest_status { + // If task has been cancelled add it to the queue again + TaskStatus::Cancelled + | TaskStatus::Cancelled_Aborted + | TaskStatus::Cancelled_NeverStarted + | TaskStatus::CancellationInProgress => { + manager + .update_aggregation_task_progress( + &aggregation_request, + TaskStatus::Registered, + None, + ) + .await?; + + prover_state + .task_channel + .try_send(Message::from(aggregation_request))?; + + Ok(Status::from(TaskStatus::Registered)) + } + // If the task has succeeded, return the proof. + TaskStatus::Success => { + let proof = manager + .get_aggregation_task_proof(&aggregation_request) + .await?; + + Ok(proof.into()) + } + // For all other statuses just return the status. 
+ status => Ok(status.clone().into()), + } + } else { + let status = statuses.into_iter().collect::(); + Ok(status.into()) + } +} + +#[derive(OpenApi)] +#[openapi(paths(proof_handler))] +struct Docs; + +pub fn create_docs() -> utoipa::openapi::OpenApi { + [ + cancel::create_docs(), + aggregate::create_docs(), + v2::proof::report::create_docs(), + v2::proof::list::create_docs(), + v2::proof::prune::create_docs(), + ] + .into_iter() + .fold(Docs::openapi(), |mut docs, curr| { + docs.merge(curr); + docs + }) +} + +pub fn create_router() -> Router { + Router::new() + .route("/", post(proof_handler)) + .nest("/cancel", cancel::create_router()) + .nest("/aggregate", aggregate::create_router()) + .nest("/report", v2::proof::report::create_router()) + .nest("/list", v2::proof::list::create_router()) + .nest("/prune", v2::proof::prune::create_router()) +} diff --git a/lib/src/builder.rs b/lib/src/builder.rs index 3269f98bb..b60be8f00 100644 --- a/lib/src/builder.rs +++ b/lib/src/builder.rs @@ -160,7 +160,7 @@ impl + DatabaseCommit + OptimisticDatabase> } = executor .execute((&block, total_difficulty).into()) .map_err(|e| { - error!("Error executing block: {:?}", e); + error!("Error executing block: {e:?}"); e })?; // Filter out the valid transactions so that the header checks only take these into account @@ -294,8 +294,8 @@ impl RethBlockBuilder { state_trie.insert_rlp(&state_trie_index, state_account)?; } - debug!("Accounts touched {:?}", account_touched); - debug!("Storages touched {:?}", storage_touched); + debug!("Accounts touched {account_touched:?}"); + debug!("Storages touched {storage_touched:?}"); Ok(state_trie.hash()) } diff --git a/lib/src/input.rs b/lib/src/input.rs index 1b0688b16..bb9c9ed9b 100644 --- a/lib/src/input.rs +++ b/lib/src/input.rs @@ -12,7 +12,9 @@ use serde_with::serde_as; #[cfg(not(feature = "std"))] use crate::no_std::*; -use crate::{consts::ChainSpec, primitives::mpt::MptNode, utils::zlib_compress_data}; +use crate::{ + consts::ChainSpec, primitives::mpt::MptNode, prover::Proof, utils::zlib_compress_data, +}; /// Represents the state of an account's storage. /// The storage trie together with the used storage slots allow us to reconstruct all the @@ -41,6 +43,42 @@ pub struct GuestInput { pub taiko: TaikoGuestInput, } +/// External aggregation input. +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct AggregationGuestInput { + /// All block proofs to prove + pub proofs: Vec, +} + +/// The raw proof data necessary to verify a proof +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct RawProof { + /// The actual proof + pub proof: Vec, + /// The resulting hash + pub input: B256, +} + +/// External aggregation input. +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct RawAggregationGuestInput { + /// All block proofs to prove + pub proofs: Vec, +} + +/// External aggregation input. 
+#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct AggregationGuestOutput { + /// The resulting hash + pub hash: B256, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ZkAggregationGuestInput { + pub image_id: [u32; 8], + pub block_inputs: Vec<B256>, +} + impl From<(Block, Header, ChainSpec, TaikoGuestInput)> for GuestInput { fn from( (block, parent_header, chain_spec, taiko): (Block, Header, ChainSpec, TaikoGuestInput), diff --git a/lib/src/protocol_instance.rs b/lib/src/protocol_instance.rs index 5036173f7..3f6271ef9 100644 --- a/lib/src/protocol_instance.rs +++ b/lib/src/protocol_instance.rs @@ -18,7 +18,7 @@ use crate::{ }, CycleTracker, }; -use log::info; +use log::{debug, info}; use reth_evm_ethereum::taiko::ANCHOR_GAS_LIMIT; #[derive(Debug, Clone)] @@ -275,6 +275,18 @@ impl ProtocolInstance { pub fn instance_hash(&self) -> B256 { // packages/protocol/contracts/verifiers/libs/LibPublicInput.sol // "VERIFY_PROOF", _chainId, _verifierContract, _tran, _newInstance, _prover, _metaHash + debug!( + "calculate instance_hash from: + chain_id: {:?}, verifier: {:?}, transition: {:?}, sgx_instance: {:?}, + prover: {:?}, block_meta: {:?}, meta_hash: {:?}", + self.chain_id, + self.verifier_address, + self.transition.clone(), + self.sgx_instance, + self.prover, + self.block_metadata, + self.meta_hash(), + ); let data = ( "VERIFY_PROOF", self.chain_id, @@ -315,6 +327,36 @@ fn bytes_to_bytes32(input: &[u8]) -> [u8; 32] { bytes } +pub fn words_to_bytes_le(words: &[u32; 8]) -> [u8; 32] { + let mut bytes = [0u8; 32]; + for i in 0..8 { + let word_bytes = words[i].to_le_bytes(); + bytes[i * 4..(i + 1) * 4].copy_from_slice(&word_bytes); + } + bytes +} + +pub fn words_to_bytes_be(words: &[u32; 8]) -> [u8; 32] { + let mut bytes = [0u8; 32]; + for i in 0..8 { + let word_bytes = words[i].to_be_bytes(); + bytes[i * 4..(i + 1) * 4].copy_from_slice(&word_bytes); + } + bytes +} + +pub fn aggregation_output_combine(public_inputs: Vec<B256>) -> Vec<u8> { + let mut output = Vec::with_capacity(public_inputs.len() * 32); + for public_input in public_inputs.iter() { + output.extend_from_slice(&public_input.0); + } + output +} + +pub fn aggregation_output(program: B256, public_inputs: Vec<B256>) -> Vec<u8> { + aggregation_output_combine([vec![program], public_inputs].concat()) +} + #[cfg(test)] mod tests { use alloy_primitives::{address, b256}; diff --git a/lib/src/prover.rs b/lib/src/prover.rs index 948f57af4..08de0229a 100644 --- a/lib/src/prover.rs +++ b/lib/src/prover.rs @@ -2,7 +2,7 @@ use reth_primitives::{ChainId, B256}; use serde::{Deserialize, Serialize}; use utoipa::ToSchema; -use crate::input::{GuestInput, GuestOutput}; +use crate::input::{AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput}; #[derive(thiserror::Error, Debug)] pub enum ProverError { @@ -26,13 +26,19 @@ pub type ProverResult<T, E = ProverError> = core::result::Result<T, E>; pub type ProverConfig = serde_json::Value; pub type ProofKey = (ChainId, B256, u8); -#[derive(Debug, Serialize, ToSchema, Deserialize, Default)] +#[derive(Clone, Debug, Serialize, ToSchema, Deserialize, Default, PartialEq, Eq, Hash)] /// The response body of a proof request. pub struct Proof { /// The proof either TEE or ZK. pub proof: Option<String>, + /// The public input + pub input: Option<B256>, /// The TEE quote. pub quote: Option<String>, + /// The assumption UUID. + pub uuid: Option<String>, + /// The kzg proof.
+ pub kzg_proof: Option, } #[async_trait::async_trait] @@ -56,5 +62,12 @@ pub trait Prover { store: Option<&mut dyn IdWrite>, ) -> ProverResult; + async fn aggregate( + input: AggregationGuestInput, + output: &AggregationGuestOutput, + config: &ProverConfig, + store: Option<&mut dyn IdWrite>, + ) -> ProverResult; + async fn cancel(proof_key: ProofKey, read: Box<&mut dyn IdStore>) -> ProverResult<()>; } diff --git a/pipeline/src/builder.rs b/pipeline/src/builder.rs index 7282cd858..9972c80de 100644 --- a/pipeline/src/builder.rs +++ b/pipeline/src/builder.rs @@ -140,7 +140,7 @@ impl CommandBuilder { println!("Using {tool}: {out}"); Some(PathBuf::from(out)) } else { - println!("Command succeeded with unknown output: {:?}", stdout); + println!("Command succeeded with unknown output: {stdout:?}"); None } } else { diff --git a/pipeline/src/executor.rs b/pipeline/src/executor.rs index a46128a09..5055018e3 100644 --- a/pipeline/src/executor.rs +++ b/pipeline/src/executor.rs @@ -100,7 +100,11 @@ impl Executor { let elf = std::fs::read(&dest.join(&name.replace('_', "-")))?; let prover = CpuProver::new(); let key_pair = prover.setup(&elf); - println!("sp1 elf vk is: {}", key_pair.1.bytes32()); + println!("sp1 elf vk bn256 is: {}", key_pair.1.bytes32()); + println!( + "sp1 elf vk hash_bytes is: {}", + hex::encode(key_pair.1.hash_bytes()) + ); } Ok(()) diff --git a/provers/risc0/builder/src/main.rs b/provers/risc0/builder/src/main.rs index b0de9edb1..523824f40 100644 --- a/provers/risc0/builder/src/main.rs +++ b/provers/risc0/builder/src/main.rs @@ -5,7 +5,10 @@ use std::path::PathBuf; fn main() { let pipeline = Risc0Pipeline::new("provers/risc0/guest", "release"); - pipeline.bins(&["risc0-guest"], "provers/risc0/driver/src/methods"); + pipeline.bins( + &["risc0-guest", "risc0-aggregation"], + "provers/risc0/driver/src/methods", + ); #[cfg(feature = "test")] pipeline.tests(&["risc0-guest"], "provers/risc0/driver/src/methods"); #[cfg(feature = "bench")] diff --git a/provers/risc0/driver/Cargo.toml b/provers/risc0/driver/Cargo.toml index a1f5e11e7..3274acce2 100644 --- a/provers/risc0/driver/Cargo.toml +++ b/provers/risc0/driver/Cargo.toml @@ -63,9 +63,9 @@ enable = [ "serde_json", "hex", "reqwest", - "lazy_static" + "lazy_static", ] cuda = ["risc0-zkvm?/cuda"] metal = ["risc0-zkvm?/metal"] bench = [] -bonsai-auto-scaling = [] \ No newline at end of file +bonsai-auto-scaling = [] diff --git a/provers/risc0/driver/src/bonsai.rs b/provers/risc0/driver/src/bonsai.rs index 2129799d8..0c8d8565f 100644 --- a/provers/risc0/driver/src/bonsai.rs +++ b/provers/risc0/driver/src/bonsai.rs @@ -1,15 +1,16 @@ use crate::{ methods::risc0_guest::RISC0_GUEST_ID, - snarks::{stark2snark, verify_groth16_snark}, + snarks::{stark2snark, verify_groth16_from_snark_receipt}, Risc0Response, }; +use alloy_primitives::B256; use log::{debug, error, info, warn}; use raiko_lib::{ primitives::keccak::keccak, prover::{IdWrite, ProofKey, ProverError, ProverResult}, }; use risc0_zkvm::{ - compute_image_id, is_dev_mode, serde::to_vec, sha::Digest, Assumption, ExecutorEnv, + compute_image_id, is_dev_mode, serde::to_vec, sha::Digest, AssumptionReceipt, ExecutorEnv, ExecutorImpl, Receipt, }; use serde::{de::DeserializeOwned, Serialize}; @@ -106,10 +107,9 @@ pub async fn verify_bonsai_receipt( let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?; let bonsai_err_log = session.logs(&client); return Err(BonsaiExecutionError::Fatal(format!( - "Workflow exited: {} - | err: {} | log: {:?}", + "Workflow exited: {} - | 
err: {} | log: {bonsai_err_log:?}", res.status, res.error_msg.unwrap_or_default(), - bonsai_err_log ))); } } @@ -120,7 +120,7 @@ pub async fn maybe_prove, elf: &[u8], expected_output: &O, - assumptions: (Vec, Vec), + assumptions: (Vec>, Vec), proof_key: ProofKey, id_store: &mut Option<&mut dyn IdWrite>, ) -> Option<(String, Receipt)> { @@ -283,20 +283,27 @@ pub async fn prove_bonsai( pub async fn bonsai_stark_to_snark( stark_uuid: String, stark_receipt: Receipt, + input: B256, ) -> ProverResult { let image_id = Digest::from(RISC0_GUEST_ID); - let (snark_uuid, snark_receipt) = stark2snark(image_id, stark_uuid, stark_receipt) - .await - .map_err(|err| format!("Failed to convert STARK to SNARK: {err:?}"))?; + let (snark_uuid, snark_receipt) = + stark2snark(image_id, stark_uuid.clone(), stark_receipt.clone()) + .await + .map_err(|err| format!("Failed to convert STARK to SNARK: {err:?}"))?; info!("Validating SNARK uuid: {snark_uuid}"); - let enc_proof = verify_groth16_snark(image_id, snark_receipt) + let enc_proof = verify_groth16_from_snark_receipt(image_id, snark_receipt) .await .map_err(|err| format!("Failed to verify SNARK: {err:?}"))?; let snark_proof = format!("0x{}", hex::encode(enc_proof)); - Ok(Risc0Response { proof: snark_proof }) + Ok(Risc0Response { + proof: snark_proof, + receipt: serde_json::to_string(&stark_receipt).unwrap(), + uuid: stark_uuid, + input, + }) } /// Prove the given ELF locally with the given input and assumptions. The segments are @@ -305,7 +312,7 @@ pub fn prove_locally( segment_limit_po2: u32, encoded_input: Vec, elf: &[u8], - assumptions: Vec, + assumptions: Vec>, profile: bool, ) -> ProverResult { debug!("Proving with segment_limit_po2 = {segment_limit_po2:?}"); diff --git a/provers/risc0/driver/src/lib.rs b/provers/risc0/driver/src/lib.rs index 177ba6742..6dd8a200c 100644 --- a/provers/risc0/driver/src/lib.rs +++ b/provers/risc0/driver/src/lib.rs @@ -2,15 +2,24 @@ #[cfg(feature = "bonsai-auto-scaling")] use crate::bonsai::auto_scaling::shutdown_bonsai; -use crate::methods::risc0_guest::RISC0_GUEST_ELF; +use crate::{ + methods::risc0_aggregation::RISC0_AGGREGATION_ELF, + methods::risc0_guest::{RISC0_GUEST_ELF, RISC0_GUEST_ID}, +}; use alloy_primitives::{hex::ToHexExt, B256}; -pub use bonsai::*; -use log::warn; +use bonsai::{cancel_proof, maybe_prove}; +use log::{info, warn}; use raiko_lib::{ - input::{GuestInput, GuestOutput}, + input::{ + AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput, + ZkAggregationGuestInput, + }, prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult}, }; -use risc0_zkvm::serde::to_vec; +use risc0_zkvm::{ + compute_image_id, default_prover, serde::to_vec, sha::Digestible, ExecutorEnv, ProverOpts, + Receipt, +}; use serde::{Deserialize, Serialize}; use serde_with::serde_as; use std::fmt::Debug; @@ -32,13 +41,19 @@ pub struct Risc0Param { #[derive(Clone, Serialize, Deserialize)] pub struct Risc0Response { pub proof: String, + pub receipt: String, + pub uuid: String, + pub input: B256, } impl From for Proof { fn from(value: Risc0Response) -> Self { Self { proof: Some(value.proof), - quote: None, + quote: Some(value.receipt), + input: Some(value.input), + uuid: Some(value.uuid), + kzg_proof: None, } } } @@ -70,25 +85,30 @@ impl Prover for Risc0Prover { encoded_input, RISC0_GUEST_ELF, &output.hash, - Default::default(), + (Vec::::new(), Vec::new()), proof_key, &mut id_store, ) .await; + let receipt = result.clone().unwrap().1.clone(); + let uuid = result.clone().unwrap().0; + let 
proof_gen_result = if result.is_some() { if config.snark && config.bonsai { let (stark_uuid, stark_receipt) = result.clone().unwrap(); - bonsai::bonsai_stark_to_snark(stark_uuid, stark_receipt) + bonsai::bonsai_stark_to_snark(stark_uuid, stark_receipt, output.hash) .await .map(|r0_response| r0_response.into()) .map_err(|e| ProverError::GuestError(e.to_string())) } else { warn!("proof is not in snark mode, please check."); let (_, stark_receipt) = result.clone().unwrap(); - Ok(Risc0Response { proof: stark_receipt.journal.encode_hex_with_prefix(), + receipt: serde_json::to_string(&receipt).unwrap(), + uuid, + input: output.hash, } .into()) } @@ -109,6 +129,83 @@ proof_gen_result } + async fn aggregate( + input: AggregationGuestInput, + _output: &AggregationGuestOutput, + config: &ProverConfig, + _id_store: Option<&mut dyn IdWrite>, + ) -> ProverResult<Proof> { + let config = Risc0Param::deserialize(config.get("risc0").unwrap()).unwrap(); + assert!( + config.snark && config.bonsai, + "Aggregation must be in bonsai snark mode" + ); + + // Extract the block proof receipts + let assumptions: Vec<Receipt> = input + .proofs + .iter() + .map(|proof| { + let receipt: Receipt = serde_json::from_str(&proof.quote.clone().unwrap()) + .expect("Failed to deserialize"); + receipt + }) + .collect::<Vec<_>>(); + let block_inputs: Vec<B256> = input + .proofs + .iter() + .map(|proof| proof.input.unwrap()) + .collect::<Vec<_>>(); + let input = ZkAggregationGuestInput { + image_id: RISC0_GUEST_ID, + block_inputs, + }; + info!("Start aggregate proofs"); + // add_assumption makes the receipt to be verified available to the prover. + let env = { + let mut env = ExecutorEnv::builder(); + for assumption in assumptions { + env.add_assumption(assumption); + } + env.write(&input).unwrap().build().unwrap() + }; + + let opts = ProverOpts::groth16(); + let receipt = default_prover() + .prove_with_opts(env, RISC0_AGGREGATION_ELF, &opts) + .unwrap() + .receipt; + + info!( + "Generate aggregation receipt journal: {:?}", + receipt.journal + ); + let aggregation_image_id = compute_image_id(RISC0_AGGREGATION_ELF).unwrap(); + let enc_proof = + snarks::verify_groth16_snark_from_receipt(aggregation_image_id, receipt.clone()) + .await + .map_err(|err| format!("Failed to verify SNARK: {err:?}"))?; + let snark_proof = format!("0x{}", hex::encode(enc_proof)); + + let proof_gen_result = Ok(Risc0Response { + proof: snark_proof, + receipt: serde_json::to_string(&receipt).unwrap(), + uuid: "".to_owned(), + input: B256::from_slice(&receipt.journal.digest().as_bytes()), + } + .into()); + + #[cfg(feature = "bonsai-auto-scaling")] + if config.bonsai { + // shutdown bonsai + shutdown_bonsai() + .await + .map_err(|e| ProverError::GuestError(e.to_string()))?; + } + + proof_gen_result + } + async fn cancel(key: ProofKey, id_store: Box<&mut dyn IdStore>) -> ProverResult<()> { let uuid = match id_store.read_id(key).await { Ok(uuid) => uuid, diff --git a/provers/risc0/driver/src/methods/mod.rs b/provers/risc0/driver/src/methods/mod.rs index 0211d22de..19219d8af 100644 --- a/provers/risc0/driver/src/methods/mod.rs +++ b/provers/risc0/driver/src/methods/mod.rs @@ -1,3 +1,4 @@ +pub mod risc0_aggregation; pub mod risc0_guest; // To build the following `$ cargo run --features test,bench --bin risc0-builder` diff --git a/provers/risc0/driver/src/methods/risc0_aggregation.rs b/provers/risc0/driver/src/methods/risc0_aggregation.rs new file mode 100644 index 000000000..06ad39e27 --- /dev/null +++ b/provers/risc0/driver/src/methods/risc0_aggregation.rs @@ -0,0 +1,5
@@ +pub const RISC0_AGGREGATION_ELF: &[u8] = + include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/risc0-aggregation"); +pub const RISC0_AGGREGATION_ID: [u32; 8] = [ + 440526723, 3767976668, 67051936, 881100330, 2605787818, 1152192925, 943988177, 1141581874, +]; diff --git a/provers/risc0/driver/src/methods/risc0_guest.rs b/provers/risc0/driver/src/methods/risc0_guest.rs index 19d5fdfdc..159152655 100644 --- a/provers/risc0/driver/src/methods/risc0_guest.rs +++ b/provers/risc0/driver/src/methods/risc0_guest.rs @@ -1,5 +1,5 @@ pub const RISC0_GUEST_ELF: &[u8] = include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/risc0-guest"); pub const RISC0_GUEST_ID: [u32; 8] = [ - 2724640415, 1388818056, 2370444677, 1329173777, 2657825669, 1524407056, 1629931902, 314750851, + 2426111784, 2252773481, 4093155148, 2853313326, 836865213, 1159934005, 790932950, 229907112, ]; diff --git a/provers/risc0/driver/src/snarks.rs b/provers/risc0/driver/src/snarks.rs index 5cc00d232..a766ccf7a 100644 --- a/provers/risc0/driver/src/snarks.rs +++ b/provers/risc0/driver/src/snarks.rs @@ -30,7 +30,7 @@ use risc0_zkvm::{ use tracing::{error as tracing_err, info as tracing_info}; -use crate::save_receipt; +use crate::bonsai::save_receipt; sol!( /// A Groth16 seal over the claimed receipt claim. @@ -150,9 +150,31 @@ pub async fn stark2snark( Ok(snark_data) } -pub async fn verify_groth16_snark( +pub async fn verify_groth16_from_snark_receipt( image_id: Digest, snark_receipt: SnarkReceipt, +) -> Result> { + let seal = encode(snark_receipt.snark.to_vec())?; + let journal_digest = snark_receipt.journal.digest(); + let post_state_digest = snark_receipt.post_state_digest.digest(); + verify_groth16_snark_impl(image_id, seal, journal_digest, post_state_digest).await +} + +pub async fn verify_groth16_snark_from_receipt( + image_id: Digest, + receipt: Receipt, +) -> Result> { + let seal = receipt.inner.groth16().unwrap().seal.clone(); + let journal_digest = receipt.journal.digest(); + let post_state_digest = receipt.claim()?.as_value().unwrap().post.digest(); + verify_groth16_snark_impl(image_id, seal, journal_digest, post_state_digest).await +} + +pub async fn verify_groth16_snark_impl( + image_id: Digest, + seal: Vec, + journal_digest: Digest, + post_state_digest: Digest, ) -> Result> { let verifier_rpc_url = std::env::var("GROTH16_VERIFIER_RPC_URL").expect("env GROTH16_VERIFIER_RPC_URL"); @@ -167,19 +189,15 @@ pub async fn verify_groth16_snark( 500, )?); - let seal = encode(snark_receipt.snark.to_vec())?; - let journal_digest = snark_receipt.journal.digest(); + let enc_seal = encode(seal)?; tracing_info!("Verifying SNARK:"); - tracing_info!("Seal: {}", hex::encode(&seal)); + tracing_info!("Seal: {}", hex::encode(&enc_seal)); tracing_info!("Image ID: {}", hex::encode(image_id.as_bytes())); - tracing_info!( - "Post State Digest: {}", - hex::encode(&snark_receipt.post_state_digest) - ); + tracing_info!("Post State Digest: {}", hex::encode(&post_state_digest)); tracing_info!("Journal Digest: {}", hex::encode(journal_digest)); let verify_call_res = IRiscZeroVerifier::new(groth16_verifier_addr, http_client) .verify( - seal.clone().into(), + enc_seal.clone().into(), image_id.as_bytes().try_into().unwrap(), journal_digest.into(), ) @@ -188,13 +206,17 @@ pub async fn verify_groth16_snark( if verify_call_res.is_ok() { tracing_info!("SNARK verified successfully using {groth16_verifier_addr:?}!"); } else { - tracing_err!("SNARK verification failed: {:?}!", verify_call_res); + tracing_err!("SNARK 
verification failed: {verify_call_res:?}!"); } - Ok((seal, B256::from_slice(image_id.as_bytes())) + Ok(make_risc0_groth16_proof(enc_seal, image_id)) +} + +pub fn make_risc0_groth16_proof(seal: Vec, image_id: Digest) -> Vec { + (seal, B256::from_slice(image_id.as_bytes())) .abi_encode() .iter() .skip(32) .copied() - .collect()) + .collect() } diff --git a/provers/risc0/guest/Cargo.toml b/provers/risc0/guest/Cargo.toml index 28091f3c9..190ac9a60 100644 --- a/provers/risc0/guest/Cargo.toml +++ b/provers/risc0/guest/Cargo.toml @@ -9,6 +9,10 @@ edition = "2021" name = "zk_op" path = "src/zk_op.rs" +[[bin]] +name = "risc0-aggregation" +path = "src/aggregation.rs" + [[bin]] name = "sha256" path = "src/benchmark/sha256.rs" diff --git a/provers/risc0/guest/src/aggregation.rs b/provers/risc0/guest/src/aggregation.rs new file mode 100644 index 000000000..240711d7d --- /dev/null +++ b/provers/risc0/guest/src/aggregation.rs @@ -0,0 +1,27 @@ +//! Aggregates multiple block proofs +#![no_main] +harness::entrypoint!(main); + +use risc0_zkvm::{guest::env, serde}; + +use raiko_lib::{ + input::ZkAggregationGuestInput, + primitives::B256, + protocol_instance::{aggregation_output, words_to_bytes_le}, +}; + +pub fn main() { + // Read the aggregation input + let input = env::read::(); + + // Verify the proofs. + for block_input in input.block_inputs.iter() { + env::verify(input.image_id, &serde::to_vec(block_input).unwrap()).unwrap(); + } + + // The aggregation output + env::commit_slice(&aggregation_output( + B256::from(words_to_bytes_le(&input.image_id)), + input.block_inputs, + )); +} diff --git a/provers/sgx/guest/src/app_args.rs b/provers/sgx/guest/src/app_args.rs index 35020f272..10f8ca18e 100644 --- a/provers/sgx/guest/src/app_args.rs +++ b/provers/sgx/guest/src/app_args.rs @@ -17,6 +17,8 @@ pub struct App { pub enum Command { /// Prove (i.e. sign) a single block and exit. OneShot(OneShotArgs), + /// Aggregate proofs + Aggregate(OneShotArgs), /// Bootstrap the application and then exit. The bootstrapping process generates the /// initial public-private key pair and stores it on the disk in an encrypted /// format using SGX encryption primitives. diff --git a/provers/sgx/guest/src/main.rs b/provers/sgx/guest/src/main.rs index accd54913..c7af5db30 100644 --- a/provers/sgx/guest/src/main.rs +++ b/provers/sgx/guest/src/main.rs @@ -3,6 +3,7 @@ extern crate secp256k1; use anyhow::{anyhow, Result}; use clap::Parser; +use one_shot::aggregate; use crate::{ app_args::{App, Command}, @@ -22,6 +23,10 @@ pub async fn main() -> Result<()> { println!("Starting one shot mode"); one_shot(args.global_opts, one_shot_args).await? } + Command::Aggregate(one_shot_args) => { + println!("Starting aggregation mode"); + aggregate(args.global_opts, one_shot_args).await? + } Command::Bootstrap => { println!("Bootstrapping the app"); bootstrap(args.global_opts)? 
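// A condensed sketch of the recursion flow the risc0 pieces above implement:
// the host adds each block receipt as an assumption, and the aggregation guest
// discharges those assumptions with env::verify. Names mirror the patch's own
// imports and constants; the standalone function shape is illustrative only.
use risc0_zkvm::{default_prover, ExecutorEnv, ProverOpts, Receipt};

fn aggregate_block_receipts(receipts: Vec<Receipt>, block_inputs: Vec<B256>) -> Receipt {
    let input = ZkAggregationGuestInput {
        // The aggregation guest re-verifies proofs of this image id.
        image_id: RISC0_GUEST_ID,
        block_inputs,
    };
    let mut builder = ExecutorEnv::builder();
    for receipt in receipts {
        // Make each block receipt available to the guest's env::verify call.
        builder.add_assumption(receipt);
    }
    let env = builder.write(&input).unwrap().build().unwrap();
    // Groth16 keeps the aggregated receipt cheap to verify on-chain.
    default_prover()
        .prove_with_opts(env, RISC0_AGGREGATION_ELF, &ProverOpts::groth16())
        .unwrap()
        .receipt
}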
diff --git a/provers/sgx/guest/src/one_shot.rs b/provers/sgx/guest/src/one_shot.rs index 4c4cfee71..156f92f9a 100644 --- a/provers/sgx/guest/src/one_shot.rs +++ b/provers/sgx/guest/src/one_shot.rs @@ -8,8 +8,11 @@ use std::{ use anyhow::{anyhow, bail, Context, Error, Result}; use base64_serde::base64_serde_type; use raiko_lib::{ - builder::calculate_block_header, consts::VerifierType, input::GuestInput, primitives::Address, - protocol_instance::ProtocolInstance, + builder::calculate_block_header, + consts::VerifierType, + input::{GuestInput, RawAggregationGuestInput}, + primitives::{keccak, Address, B256}, + protocol_instance::{aggregation_output_combine, ProtocolInstance}, }; use secp256k1::{Keypair, SecretKey}; use serde::Serialize; @@ -143,6 +146,7 @@ pub async fn one_shot(global_opts: GlobalOpts, args: OneShotArgs) -> Result<()> let sig = sign_message(&prev_privkey, pi_hash)?; // Create the proof for the onchain SGX verifier + // 4(id) + 20(new) + 65(sig) = 89 const SGX_PROOF_LEN: usize = 89; let mut proof = Vec::with_capacity(SGX_PROOF_LEN); proof.extend(args.sgx_instance_id.to_be_bytes()); @@ -160,6 +164,86 @@ pub async fn one_shot(global_opts: GlobalOpts, args: OneShotArgs) -> Result<()> "quote": hex::encode(quote), "public_key": format!("0x{new_pubkey}"), "instance_address": new_instance.to_string(), + "input": pi_hash.to_string(), + }); + println!("{data}"); + + // Print out general SGX information + print_sgx_info() +} + +pub async fn aggregate(global_opts: GlobalOpts, args: OneShotArgs) -> Result<()> { + // Make sure this SGX instance was bootstrapped + let prev_privkey = load_bootstrap(&global_opts.secrets_dir) + .or_else(|_| bail!("Application was not bootstrapped or has a deprecated bootstrap.")) + .unwrap(); + + println!("Global options: {global_opts:?}, OneShot options: {args:?}"); + + let new_pubkey = public_key(&prev_privkey); + let new_instance = public_key_to_address(&new_pubkey); + + let input: RawAggregationGuestInput = + bincode::deserialize_from(std::io::stdin()).expect("unable to deserialize input"); + + // Make sure the chain of old/new public keys is preserved + let old_instance = Address::from_slice(&input.proofs[0].proof.clone()[4..24]); + let mut cur_instance = old_instance; + + // Verify the proofs + for proof in input.proofs.iter() { + // TODO: verify protocol instance data so we can trust the old/new instance data + assert_eq!( + recover_signer_unchecked(&proof.proof.clone()[24..].try_into().unwrap(), &proof.input,) + .unwrap(), + cur_instance, + ); + cur_instance = Address::from_slice(&proof.proof.clone()[4..24]); + } + + // Current public key needs to match latest proof new public key + assert_eq!(cur_instance, new_instance); + + // Calculate the aggregation hash + let aggregation_hash = keccak::keccak(aggregation_output_combine( + [ + vec![ + B256::left_padding_from(old_instance.as_ref()), + B256::left_padding_from(new_instance.as_ref()), + ], + input + .proofs + .iter() + .map(|proof| proof.input) + .collect::>(), + ] + .concat(), + )); + + // Sign the public aggregation hash + let sig = sign_message(&prev_privkey, aggregation_hash.into())?; + + // Create the proof for the onchain SGX verifier + const SGX_PROOF_LEN: usize = 109; + // 4(id) + 20(old) + 20(new) + 65(sig) = 109 + let mut proof = Vec::with_capacity(SGX_PROOF_LEN); + proof.extend(args.sgx_instance_id.to_be_bytes()); + proof.extend(old_instance); + proof.extend(new_instance); + proof.extend(sig); + let proof = hex::encode(proof); + + // Store the public key address in the attestation data + 
save_attestation_user_report_data(new_instance)?; + + // Print out the proof and updated public info + let quote = get_sgx_quote()?; + let data = serde_json::json!({ + "proof": format!("0x{proof}"), + "quote": hex::encode(quote), + "public_key": format!("0x{new_pubkey}"), + "instance_address": new_instance.to_string(), + "input": B256::from(aggregation_hash).to_string(), }); println!("{data}"); diff --git a/provers/sgx/prover/Cargo.toml b/provers/sgx/prover/Cargo.toml index 69c0c3570..0c5f5a6c9 100644 --- a/provers/sgx/prover/Cargo.toml +++ b/provers/sgx/prover/Cargo.toml @@ -24,6 +24,7 @@ alloy-transport-http = { workspace = true } pem = { version = "3.0.4", optional = true } url = { workspace = true } anyhow = { workspace = true } +hex = { workspace = true } [features] default = ["dep:pem"] diff --git a/provers/sgx/prover/src/lib.rs b/provers/sgx/prover/src/lib.rs index 7f7688ac7..a74ee0e06 100644 --- a/provers/sgx/prover/src/lib.rs +++ b/provers/sgx/prover/src/lib.rs @@ -5,12 +5,16 @@ use std::{ fs::{copy, create_dir_all, remove_file}, path::{Path, PathBuf}, process::{Command as StdCommand, Output, Stdio}, - str, + str::{self, FromStr}, }; use once_cell::sync::Lazy; use raiko_lib::{ - input::{GuestInput, GuestOutput}, + input::{ + AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput, + RawAggregationGuestInput, RawProof, + }, + primitives::B256, prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult}, }; use serde::{Deserialize, Serialize}; @@ -42,13 +46,17 @@ pub struct SgxResponse { /// proof format: 4b(id)+20b(pubkey)+65b(signature) pub proof: String, pub quote: String, + pub input: B256, } impl From for Proof { fn from(value: SgxResponse) -> Self { Self { proof: Some(value.proof), + input: Some(value.input), quote: Some(value.quote), + uuid: None, + kzg_proof: None, } } } @@ -147,6 +155,87 @@ impl Prover for SgxProver { sgx_proof.map(|r| r.into()) } + async fn aggregate( + input: AggregationGuestInput, + _output: &AggregationGuestOutput, + config: &ProverConfig, + _id_store: Option<&mut dyn IdWrite>, + ) -> ProverResult { + let sgx_param = SgxParam::deserialize(config.get("sgx").unwrap()).unwrap(); + + // Support both SGX and the direct backend for testing + let direct_mode = match env::var("SGX_DIRECT") { + Ok(value) => value == "1", + Err(_) => false, + }; + + println!( + "WARNING: running SGX in {} mode!", + if direct_mode { + "direct (a.k.a. 
simulation)" + } else { + "hardware" + } + ); + + // The working directory + let mut cur_dir = env::current_exe() + .expect("Fail to get current directory") + .parent() + .unwrap() + .to_path_buf(); + + // When running in tests we might be in a child folder + if cur_dir.ends_with("deps") { + cur_dir = cur_dir.parent().unwrap().to_path_buf(); + } + + println!("Current directory: {cur_dir:?}\n"); + // Working paths + PRIVATE_KEY + .get_or_init(|| async { cur_dir.join("secrets").join(PRIV_KEY_FILENAME) }) + .await; + GRAMINE_MANIFEST_TEMPLATE + .get_or_init(|| async { + cur_dir + .join(CONFIG) + .join("sgx-guest.local.manifest.template") + }) + .await; + + // The gramine command (gramine or gramine-direct for testing in non-SGX environment) + let gramine_cmd = || -> StdCommand { + let mut cmd = if direct_mode { + StdCommand::new("gramine-direct") + } else { + let mut cmd = StdCommand::new("sudo"); + cmd.arg("gramine-sgx"); + cmd + }; + cmd.current_dir(&cur_dir).arg(ELF_NAME); + cmd + }; + + // Setup: run this once while setting up your SGX instance + if sgx_param.setup { + setup(&cur_dir, direct_mode).await?; + } + + let mut sgx_proof = if sgx_param.bootstrap { + bootstrap(cur_dir.clone().join("secrets"), gramine_cmd()).await + } else { + // Dummy proof: it's ok when only setup/bootstrap was requested + Ok(SgxResponse::default()) + }; + + if sgx_param.prove { + // overwrite sgx_proof as the bootstrap quote stays the same in bootstrap & prove. + sgx_proof = aggregate(gramine_cmd(), input.clone(), sgx_param.instance_id).await + } + + sgx_proof.map(|r| r.into()) + } + async fn cancel(_proof_key: ProofKey, _read: Box<&mut dyn IdStore>) -> ProverResult<()> { Ok(()) } @@ -303,6 +392,54 @@ async fn prove( .map_err(|e| ProverError::GuestError(e.to_string()))? } +async fn aggregate( + mut gramine_cmd: StdCommand, + input: AggregationGuestInput, + instance_id: u64, +) -> ProverResult { + // Extract the useful parts of the proof here so the guest doesn't have to do it + let raw_input = RawAggregationGuestInput { + proofs: input + .proofs + .iter() + .map(|proof| RawProof { + input: proof.clone().input.unwrap(), + proof: hex::decode(&proof.clone().proof.unwrap()[2..]).unwrap(), + }) + .collect(), + }; + + tokio::task::spawn_blocking(move || { + let mut child = gramine_cmd + .arg("aggregate") + .arg("--sgx-instance-id") + .arg(instance_id.to_string()) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .map_err(|e| format!("Could not spawn gramine cmd: {e}"))?; + let stdin = child.stdin.as_mut().expect("Failed to open stdin"); + let input_success = bincode::serialize_into(stdin, &raw_input); + let output_success = child.wait_with_output(); + + match (input_success, output_success) { + (Ok(_), Ok(output)) => { + handle_output(&output, "SGX prove")?; + Ok(parse_sgx_result(output.stdout)?) + } + (Err(i), output_success) => Err(ProverError::GuestError(format!( + "Can not serialize input for SGX {i}, output is {output_success:?}" + ))), + (Ok(_), Err(output_err)) => Err(ProverError::GuestError( + handle_gramine_error("Could not run SGX guest prover", output_err).to_string(), + )), + } + }) + .await + .map_err(|e| ProverError::GuestError(e.to_string()))? 
+} + fn parse_sgx_result(output: Vec) -> ProverResult { let mut json_value: Option = None; let output = String::from_utf8(output).map_err(|e| e.to_string())?; @@ -324,6 +461,7 @@ fn parse_sgx_result(output: Vec) -> ProverResult { Ok(SgxResponse { proof: extract_field("proof"), quote: extract_field("quote"), + input: B256::from_str(&extract_field("input")).unwrap(), }) } diff --git a/provers/sp1/builder/src/main.rs b/provers/sp1/builder/src/main.rs index 7db899a13..fe696594e 100644 --- a/provers/sp1/builder/src/main.rs +++ b/provers/sp1/builder/src/main.rs @@ -5,7 +5,7 @@ use std::path::PathBuf; fn main() { let pipeline = Sp1Pipeline::new("provers/sp1/guest", "release"); - pipeline.bins(&["sp1-guest"], "provers/sp1/guest/elf"); + pipeline.bins(&["sp1-guest", "sp1-aggregation"], "provers/sp1/guest/elf"); #[cfg(feature = "test")] pipeline.tests(&["sp1-guest"], "provers/sp1/guest/elf"); #[cfg(feature = "bench")] diff --git a/provers/sp1/driver/src/lib.rs b/provers/sp1/driver/src/lib.rs index c0c3f60d1..8de517f52 100644 --- a/provers/sp1/driver/src/lib.rs +++ b/provers/sp1/driver/src/lib.rs @@ -1,8 +1,12 @@ #![cfg(feature = "enable")] #![feature(iter_advance_by)] +use once_cell::sync::Lazy; use raiko_lib::{ - input::{GuestInput, GuestOutput}, + input::{ + AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput, + ZkAggregationGuestInput, + }, prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult}, Measurement, }; @@ -13,16 +17,27 @@ use sp1_sdk::{ action, network::client::NetworkClient, proto::network::{ProofMode, UnclaimReason}, + SP1Proof, SP1ProofWithPublicValues, SP1VerifyingKey, }; use sp1_sdk::{HashableKey, ProverClient, SP1Stdin}; -use std::{borrow::BorrowMut, env}; -use tracing::info; +use std::{ + borrow::BorrowMut, + env, fs, + path::{Path, PathBuf}, +}; +use tracing::{debug, error, info}; mod proof_verify; use proof_verify::remote_contract_verify::verify_sol_by_contract_call; pub const ELF: &[u8] = include_bytes!("../../guest/elf/sp1-guest"); +pub const AGGREGATION_ELF: &[u8] = include_bytes!("../../guest/elf/sp1-aggregation"); const SP1_PROVER_CODE: u8 = 1; +static FIXTURE_PATH: Lazy = + Lazy::new(|| Path::new(env!("CARGO_MANIFEST_DIR")).join("../contracts/src/fixtures/")); +static CONTRACT_PATH: Lazy = + Lazy::new(|| Path::new(env!("CARGO_MANIFEST_DIR")).join("../contracts/src/exports/")); +pub static VERIFIER: Lazy> = Lazy::new(init_verifier); #[serde_as] #[derive(Clone, Debug, Serialize, Deserialize)] @@ -67,15 +82,27 @@ pub enum ProverMode { impl From for Proof { fn from(value: Sp1Response) -> Self { Self { - proof: Some(value.proof), - quote: None, + proof: value.proof, + quote: value + .sp1_proof + .as_ref() + .map(|p| serde_json::to_string(&p.proof).unwrap()), + input: value + .sp1_proof + .as_ref() + .map(|p| B256::from_slice(p.public_values.as_slice())), + uuid: value.vkey.map(|v| serde_json::to_string(&v).unwrap()), + kzg_proof: None, } } } #[derive(Clone, Serialize, Deserialize)] pub struct Sp1Response { - pub proof: String, + pub proof: Option, + /// for aggregation + pub sp1_proof: Option, + pub vkey: Option, } pub struct Sp1Prover; @@ -90,6 +117,8 @@ impl Prover for Sp1Prover { let param = Sp1Param::deserialize(config.get("sp1").unwrap()).unwrap(); let mode = param.prover.clone().unwrap_or_else(get_env_mock); + println!("param: {param:?}"); + let mut stdin = SP1Stdin::new(); stdin.write(&input); @@ -118,8 +147,7 @@ impl Prover for Sp1Prover { RecursionMode::Compressed => prove_action.compressed().run(), 
RecursionMode::Plonk => prove_action.plonk().run(), } - .map_err(|e| ProverError::GuestError(format!("Sp1: local proving failed: {}", e))) - .unwrap() + .map_err(|e| ProverError::GuestError(format!("Sp1: local proving failed: {e}")))? } else { let network_prover = sp1_sdk::NetworkProver::new(); @@ -138,17 +166,22 @@ impl Prover for Sp1Prover { .await?; } info!( - "Sp1 Prover: block {:?} - proof id {:?}", - output.header.number, proof_id + "Sp1 Prover: block {:?} - proof id {proof_id:?}", + output.header.number ); network_prover .wait_proof::(&proof_id, None) .await - .map_err(|e| ProverError::GuestError(format!("Sp1: network proof failed {:?}", e))) - .unwrap() + .map_err(|e| ProverError::GuestError(format!("Sp1: network proof failed {e:?}")))? }; - let proof_bytes = prove_result.bytes(); + let proof_bytes = match param.recursion { + RecursionMode::Compressed => { + info!("Compressed proof is used in aggregation mode only"); + vec![] + } + _ => prove_result.bytes(), + }; if param.verify { let time = Measurement::start("verify", false); let pi_hash = prove_result @@ -158,34 +191,36 @@ impl Prover for Sp1Prover { .read::<[u8; 32]>(); let fixture = RaikoProofFixture { vkey: vk.bytes32(), - public_values: pi_hash.into(), - proof: proof_bytes.clone(), + public_values: B256::from_slice(&pi_hash).to_string(), + proof: reth_primitives::hex::encode_prefixed(&proof_bytes), }; verify_sol_by_contract_call(&fixture).await?; time.stop_with("==> Verification complete"); } - let proof_string = if proof_bytes.is_empty() { - None - } else { + let proof_string = (!proof_bytes.is_empty()).then_some( // 0x + 64 bytes of the vkey + the proof // vkey itself contains 0x prefix - Some(format!( + format!( "{}{}", vk.bytes32(), reth_primitives::hex::encode(proof_bytes) - )) - }; + ), + ); info!( - "Sp1 Prover: block {:?} completed! proof: {:?}", - output.header.number, proof_string + "Sp1 Prover: block {:?} completed! 
proof: {proof_string:?}", + output.header.number, ); - Ok::<_, ProverError>(Proof { - proof: proof_string, - quote: None, - }) + Ok::<_, ProverError>( + Sp1Response { + proof: proof_string, + sp1_proof: Some(prove_result), + vkey: Some(vk), + } + .into(), + ) } async fn cancel(key: ProofKey, id_store: Box<&mut dyn IdStore>) -> ProverResult<()> { @@ -210,6 +245,110 @@ impl Prover for Sp1Prover { id_store.remove_id(key).await?; Ok(()) } + + async fn aggregate( + input: AggregationGuestInput, + _output: &AggregationGuestOutput, + config: &ProverConfig, + _store: Option<&mut dyn IdWrite>, + ) -> ProverResult { + let param = Sp1Param::deserialize(config.get("sp1").unwrap()).unwrap(); + + info!("aggregate proof with param: {param:?}"); + + let block_inputs: Vec = input + .proofs + .iter() + .map(|proof| proof.input.unwrap()) + .collect::>(); + let block_proof_vk = serde_json::from_str::( + &input.proofs.first().unwrap().uuid.clone().unwrap(), + ) + .map_err(|e| ProverError::GuestError(format!("Failed to parse SP1 vk: {e}")))?; + let stark_vk = block_proof_vk.vk.clone(); + let image_id = block_proof_vk.hash_u32(); + let aggregation_input = ZkAggregationGuestInput { + image_id, + block_inputs, + }; + info!( + "Aggregating {:?} proofs with input: {aggregation_input:?}", + input.proofs.len(), + ); + + let mut stdin = SP1Stdin::new(); + stdin.write(&aggregation_input); + for proof in input.proofs.iter() { + let sp1_proof = serde_json::from_str::(&proof.quote.clone().unwrap()) + .map_err(|e| ProverError::GuestError(format!("Failed to parse SP1 proof: {e}")))?; + match sp1_proof { + SP1Proof::Compressed(block_proof) => { + stdin.write_proof(block_proof, stark_vk.clone()); + } + _ => { + error!("unsupported proof type for aggregation: {sp1_proof:?}"); + } + } + } + + // Generate the proof for the given program. 
+ let client = param + .prover + .map(|mode| match mode { + ProverMode::Mock => ProverClient::mock(), + ProverMode::Local => ProverClient::local(), + ProverMode::Network => ProverClient::network(), + }) + .unwrap_or_else(ProverClient::new); + + let (pk, vk) = client.setup(AGGREGATION_ELF); + info!( + "sp1 aggregate: {:?} based on {:?} blocks with vk {:?}", + reth_primitives::hex::encode_prefixed(stark_vk.hash_bytes()), + input.proofs.len(), + vk.bytes32() + ); + + let prove_result = client + .prove(&pk, stdin) + .plonk() + .run() + .expect("proving failed"); + + let proof_bytes = prove_result.bytes(); + if param.verify { + let time = Measurement::start("verify", false); + let aggregation_pi = prove_result.clone().borrow_mut().public_values.raw(); + let fixture = RaikoProofFixture { + vkey: vk.bytes32().to_string(), + public_values: reth_primitives::hex::encode_prefixed(aggregation_pi), + proof: reth_primitives::hex::encode_prefixed(&proof_bytes), + }; + + verify_sol(&fixture)?; + time.stop_with("==> Verification complete"); + } + + let proof = (!proof_bytes.is_empty()).then_some( + // aggregation vkey + block guest vkey hash + the proof + // only the aggregation vkey carries the 0x prefix + format!( + "{}{}{}", + vk.bytes32(), + reth_primitives::hex::encode(stark_vk.hash_bytes()), + reth_primitives::hex::encode(proof_bytes) + ), + ); + + Ok::<_, ProverError>( + Sp1Response { + proof, + sp1_proof: None, + vkey: None, + } + .into(), + ) + } } fn get_env_mock() -> ProverMode { @@ -225,13 +364,65 @@ fn get_env_mock() -> ProverMode { } } +fn init_verifier() -> Result { + // In cargo run, Cargo sets the working directory to the root of the workspace + let contract_path = &*CONTRACT_PATH; + info!("Contract dir: {contract_path:?}"); + let artifacts_dir = sp1_sdk::install::try_install_circuit_artifacts(); + // Create the destination directory if it doesn't exist + fs::create_dir_all(contract_path)?; + + // Read the entries in the source directory + for entry in fs::read_dir(artifacts_dir)? { + let entry = entry?; + let src = entry.path(); + + // Check if the entry is a file and ends with .sol + if src.is_file() && src.extension().map(|s| s == "sol").unwrap_or(false) { + let out = contract_path.join(src.file_name().unwrap()); + fs::copy(&src, &out)?; + println!("Copied: {:?}", src.file_name().unwrap()); + } + } + Ok(contract_path.clone()) +} + /// A fixture that can be used to test the verification of SP1 zkVM proofs inside Solidity. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub(crate) struct RaikoProofFixture { vkey: String, - public_values: B256, - proof: Vec, + public_values: String, + proof: String, +} + +fn verify_sol(fixture: &RaikoProofFixture) -> ProverResult<()> { + assert!(VERIFIER.is_ok()); + debug!("===> Fixture: {fixture:#?}"); + + // Save the fixture to a file. 
+ let fixture_path = &*FIXTURE_PATH; + info!("Writing fixture to: {fixture_path:?}"); + + if !fixture_path.exists() { + std::fs::create_dir_all(fixture_path.clone()) + .map_err(|e| ProverError::GuestError(format!("Failed to create fixture path: {e}")))?; + } + std::fs::write( + fixture_path.join("fixture.json"), + serde_json::to_string_pretty(&fixture).unwrap(), + ) + .map_err(|e| ProverError::GuestError(format!("Failed to write fixture: {e}")))?; + + let child = std::process::Command::new("forge") + .arg("test") + .current_dir(&*CONTRACT_PATH) + .stdout(std::process::Stdio::inherit()) // Inherit the parent process' stdout + .spawn(); + info!("Verification started {child:?}"); + child.map_err(|e| ProverError::GuestError(format!("Failed to run forge: {e}")))?; + + Ok(()) } #[cfg(test)] @@ -261,6 +452,11 @@ mod test { println!("{json:?} {deserialized:?}"); } + #[test] + fn test_init_verifier() { + VERIFIER.as_ref().expect("Failed to init verifier"); + } + #[test] fn run_unittest_elf() { // TODO(Cecilia): imple GuestInput::mock() for unit test diff --git a/provers/sp1/driver/src/proof_verify/remote_contract_verify.rs b/provers/sp1/driver/src/proof_verify/remote_contract_verify.rs index 7474fad4e..6606a041d 100644 --- a/provers/sp1/driver/src/proof_verify/remote_contract_verify.rs +++ b/provers/sp1/driver/src/proof_verify/remote_contract_verify.rs @@ -32,13 +32,11 @@ pub(crate) async fn verify_sol_by_contract_call(fixture: &RaikoProofFixture) -> let provider = ProviderBuilder::new().on_http(Url::parse(&sp1_verifier_rpc_url).unwrap()); let program_key: B256 = B256::from_str(&fixture.vkey).unwrap(); - let public_value = fixture.public_values; + let public_value = fixture.public_values.clone(); let proof_bytes = fixture.proof.clone(); info!( - "verify sp1 proof with program key: {:?} public value: {:?} proof: {:?}", - program_key, - public_value, + "verify sp1 proof with program key: {program_key:?} public value: {public_value:?} proof: {:?}", reth_primitives::hex::encode(&proof_bytes) ); @@ -50,7 +48,7 @@ pub(crate) async fn verify_sol_by_contract_call(fixture: &RaikoProofFixture) -> if verify_call_res.is_ok() { info!("SP1 proof verified successfully using {sp1_verifier_addr:?}!"); } else { - error!("SP1 proof verification failed: {:?}!", verify_call_res); + error!("SP1 proof verification failed: {verify_call_res:?}!"); } Ok(()) diff --git a/provers/sp1/driver/src/verifier.rs b/provers/sp1/driver/src/verifier.rs index f1f2454c9..20c760e92 100644 --- a/provers/sp1/driver/src/verifier.rs +++ b/provers/sp1/driver/src/verifier.rs @@ -31,7 +31,7 @@ async fn main() { } }) .unwrap_or_else(|| PathBuf::from(DATA).join("taiko_mainnet-328833.json")); - println!("Reading GuestInput from {:?}", path); + println!("Reading GuestInput from {path:?}"); let json = std::fs::read_to_string(path).unwrap(); // Deserialize the input. 
diff --git a/provers/sp1/guest/Cargo.lock b/provers/sp1/guest/Cargo.lock index 3f00879a3..cfa9f3a4d 100644 --- a/provers/sp1/guest/Cargo.lock +++ b/provers/sp1/guest/Cargo.lock @@ -3504,7 +3504,7 @@ dependencies = [ "size", "snowbridge-amcl", "sp1-derive", - "sp1-primitives", + "sp1-primitives 1.1.1", "static_assertions", "strum", "strum_macros", @@ -3549,8 +3549,9 @@ dependencies = [ [[package]] name = "sp1-lib" -version = "1.2.0-rc1" -source = "git+https://github.com/succinctlabs/sp1?branch=dev#e8efd0019c8be52c6c4cecfea6259ab90db4148a" +version = "1.2.0-rc2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b85660c40c7b40a65c706816d9157ef1b084099a80275c9b4d650f53067e667f" dependencies = [ "anyhow", "bincode", @@ -3562,9 +3563,9 @@ dependencies = [ [[package]] name = "sp1-lib" -version = "1.2.0-rc2" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b85660c40c7b40a65c706816d9157ef1b084099a80275c9b4d650f53067e667f" +checksum = "413956de14568d7fb462213b9505ad4607d75c875301b9eca567cfb2e58eaac1" dependencies = [ "anyhow", "bincode", @@ -3588,10 +3589,25 @@ dependencies = [ "p3-symmetric", ] +[[package]] +name = "sp1-primitives" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efbeba375fe59917f162f1808c280d2e39e4698dc7eeac83936b6e70c2f8dbbc" +dependencies = [ + "itertools 0.13.0", + "lazy_static", + "p3-baby-bear", + "p3-field", + "p3-poseidon2", + "p3-symmetric", +] + [[package]] name = "sp1-zkvm" -version = "1.2.0-rc1" -source = "git+https://github.com/succinctlabs/sp1?branch=dev#e8efd0019c8be52c6c4cecfea6259ab90db4148a" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66c525f67cfd3f65950f01c713a72c41a5d44d289155644c8ace4ec264098039" dependencies = [ "bincode", "cfg-if", @@ -3599,10 +3615,13 @@ dependencies = [ "lazy_static", "libm", "once_cell", + "p3-baby-bear", + "p3-field", "rand", "serde", "sha2", - "sp1-lib 1.2.0-rc1", + "sp1-lib 2.0.0", + "sp1-primitives 2.0.0", ] [[package]] diff --git a/provers/sp1/guest/Cargo.toml b/provers/sp1/guest/Cargo.toml index efa74446b..3063cc5be 100644 --- a/provers/sp1/guest/Cargo.toml +++ b/provers/sp1/guest/Cargo.toml @@ -9,6 +9,10 @@ edition = "2021" name = "zk_op" path = "src/zk_op.rs" +[[bin]] +name = "sp1-aggregation" +path = "src/aggregation.rs" + [[bin]] name = "sha256" path = "src/benchmark/sha256.rs" @@ -33,9 +37,9 @@ path = "src/benchmark/bn254_mul.rs" [dependencies] raiko-lib = { path = "../../../lib", features = ["std", "sp1"] } -sp1-zkvm = { git = "https://github.com/succinctlabs/sp1", branch = "dev" } -sp1-core = { version = "1.1.1"} -sha2-v0-10-8 = { git = "https://github.com/sp1-patches/RustCrypto-hashes", package = "sha2", branch = "patch-v0.10.8" } +sp1-zkvm = { version = "2.0.0", features = ["verify"] } +sp1-core = { version = "1.1.1" } +sha2 = { git = "https://github.com/sp1-patches/RustCrypto-hashes", package = "sha2", branch = "patch-v0.10.8" } secp256k1 = { git = "https://github.com/sp1-patches/rust-secp256k1", branch = "patch-secp256k1-v0.29.0" } harness-core = { path = "../../../harness/core" } harness = { path = "../../../harness/macro", features = ["sp1"] } @@ -46,7 +50,10 @@ revm-precompile = { git = "https://github.com/taikoxyz/revm.git", branch = "v36- "c-kzg", ] } bincode = "1.3.3" -reth-primitives = { git = "https://github.com/taikoxyz/taiko-reth.git", branch = "v1.0.0-rc.2-taiko", default-features = false, features = ["alloy-compat", "taiko"] } 
+reth-primitives = { git = "https://github.com/taikoxyz/taiko-reth.git", branch = "v1.0.0-rc.2-taiko", default-features = false, features = [ + "alloy-compat", + "taiko", +] } lazy_static = "1.4.0" num-bigint = { version = "0.4.6", default-features = false } diff --git a/provers/sp1/guest/elf/sp1-aggregation b/provers/sp1/guest/elf/sp1-aggregation new file mode 100755 index 000000000..ed3c2c31d Binary files /dev/null and b/provers/sp1/guest/elf/sp1-aggregation differ diff --git a/provers/sp1/guest/src/aggregation.rs b/provers/sp1/guest/src/aggregation.rs new file mode 100644 index 000000000..84d4bde31 --- /dev/null +++ b/provers/sp1/guest/src/aggregation.rs @@ -0,0 +1,31 @@ +//! Aggregates multiple block proofs +#![no_main] +sp1_zkvm::entrypoint!(main); + +use sha2::{Digest, Sha256}; + +use raiko_lib::{ + input::ZkAggregationGuestInput, + primitives::B256, + protocol_instance::{aggregation_output, words_to_bytes_be}, +}; + +pub fn main() { + // Read the aggregation input + let input = sp1_zkvm::io::read::(); + + // Verify the block proofs. + for block_input in input.block_inputs.iter() { + sp1_zkvm::lib::verify::verify_sp1_proof( + &input.image_id, + &Sha256::digest(block_input).into(), + ); + } + + // The aggregation output + sp1_zkvm::io::commit_slice(&aggregation_output( + B256::from(words_to_bytes_be(&input.image_id)), + input.block_inputs, + )); +} + diff --git a/provers/sp1/guest/src/benchmark/bn254_add.rs b/provers/sp1/guest/src/benchmark/bn254_add.rs index 096b65468..1f5729639 100644 --- a/provers/sp1/guest/src/benchmark/bn254_add.rs +++ b/provers/sp1/guest/src/benchmark/bn254_add.rs @@ -17,11 +17,11 @@ fn main() { ]); let op = Sp1Operator {}; - + let ct = CycleTracker::start("bn128_run_add"); let res = op.bn128_run_add(&input).unwrap(); ct.end(); - + let hi = res[..32].to_vec(); let lo = res[32..].to_vec(); diff --git a/provers/sp1/guest/src/benchmark/bn254_mul.rs b/provers/sp1/guest/src/benchmark/bn254_mul.rs index 664947de0..ae1ede10e 100644 --- a/provers/sp1/guest/src/benchmark/bn254_mul.rs +++ b/provers/sp1/guest/src/benchmark/bn254_mul.rs @@ -19,7 +19,7 @@ fn main() { let ct = CycleTracker::start("bn128_run_mul"); let res = op.bn128_run_mul(&input).unwrap(); ct.end(); - + let hi = res[..32].to_vec(); let lo = res[32..].to_vec(); sp1_zkvm::io::commit(&hi); diff --git a/provers/sp1/guest/src/benchmark/sha256.rs b/provers/sp1/guest/src/benchmark/sha256.rs index 9c5908b13..e6c574333 100644 --- a/provers/sp1/guest/src/benchmark/sha256.rs +++ b/provers/sp1/guest/src/benchmark/sha256.rs @@ -13,7 +13,7 @@ fn main() { ]); let op = Sp1Operator {}; - + let ct = CycleTracker::start("sha256_run"); let res = op.sha256_run(&input).unwrap(); ct.end(); diff --git a/provers/sp1/guest/src/sys.rs b/provers/sp1/guest/src/sys.rs index 04a3c18d7..f9eed1c93 100644 --- a/provers/sp1/guest/src/sys.rs +++ b/provers/sp1/guest/src/sys.rs @@ -39,4 +39,5 @@ pub unsafe extern "C" fn free(_size: *const c_void) { #[no_mangle] pub extern "C" fn __ctzsi2(x: u32) -> u32 { x.trailing_zeros() -} \ No newline at end of file +} + diff --git a/provers/sp1/guest/src/zk_op.rs b/provers/sp1/guest/src/zk_op.rs index e6ed28be4..b28be5e20 100644 --- a/provers/sp1/guest/src/zk_op.rs +++ b/provers/sp1/guest/src/zk_op.rs @@ -1,15 +1,14 @@ -use num_bigint::BigUint; use ::secp256k1::SECP256K1; +use num_bigint::BigUint; use reth_primitives::public_key_to_address; use revm_precompile::{bn128::ADD_INPUT_LEN, utilities::right_pad, zk_op::ZkvmOperator, Error}; use secp256k1::{ ecdsa::{RecoverableSignature, RecoveryId}, Message, }; 
-use sha2_v0_10_8 as sp1_sha2; +use sha2 as sp1_sha2; use sp1_core::utils::ec::{weierstrass::bn254::Bn254, AffinePoint}; - #[derive(Debug)] pub struct Sp1Operator; @@ -117,7 +116,7 @@ harness::zk_suits!( p.x().to_big_endian(&mut p_x).unwrap(); p.y().to_big_endian(&mut p_y).unwrap(); - println!("{:?}, {:?}:?", p_x, p_y); + println!("{p_x:?}, {p_y:?}:?"); // Deserialize AffinePoint in Sp1 let p = be_bytes_to_point(&input); @@ -154,4 +153,4 @@ harness::zk_suits!( assert!(G1_LE == [p.x.to_bytes_le(), p.y.to_bytes_le()].concat()); } } -); \ No newline at end of file +); diff --git a/script/prove-block.sh b/script/prove-block.sh index 7b0d387e7..8e3113e89 100755 --- a/script/prove-block.sh +++ b/script/prove-block.sh @@ -58,6 +58,16 @@ elif [ "$proof" == "sp1" ]; then "verify": false } ' +elif [ "$proof" == "sp1-aggregation" ]; then + proofParam=' + "proof_type": "sp1", + "blob_proof_type": "proof_of_equivalence", + "sp1": { + "recursion": "compressed", + "prover": "network", + "verify": false + } + ' elif [ "$proof" == "sgx" ]; then proofParam=' "proof_type": "sgx", @@ -134,13 +144,13 @@ for block in $(eval echo {$rangeStart..$rangeEnd}); do fi echo "- proving block $block" - curl --location --request POST 'http://localhost:8080/proof' \ + curl --location --request POST 'http://localhost:8080/v3/proof' \ --header 'Content-Type: application/json' \ --header 'Authorization: Bearer 4cbd753fbcbc2639de804f8ce425016a50e0ecd53db00cb5397912e83f5e570e' \ --data-raw "{ \"network\": \"$chain\", \"l1_network\": \"$l1_network\", - \"block_number\": $block, + \"block_numbers\": [[$block, null], [$(($block+1)), null]], \"prover\": \"$prover\", \"graffiti\": \"$graffiti\", $proofParam diff --git a/tasks/src/adv_sqlite.rs b/tasks/src/adv_sqlite.rs index 120c0d43d..96f9e4bb3 100644 --- a/tasks/src/adv_sqlite.rs +++ b/tasks/src/adv_sqlite.rs @@ -159,6 +159,7 @@ use std::{ }; use chrono::{DateTime, Utc}; +use raiko_core::interfaces::AggregationOnlyRequest; use raiko_lib::{ primitives::B256, prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult}, @@ -575,7 +576,7 @@ impl TaskDb { ":blockhash": blockhash.to_vec(), ":proofsys_id": proof_system as u8, ":prover": prover, - ":status_id": status as i32, + ":status_id": i32::from(status), ":proof": proof.map(hex::encode) })?; @@ -943,6 +944,36 @@ impl TaskManager for SqliteTaskManager { let task_db = self.arc_task_db.lock().await; task_db.list_stored_ids() } + + async fn enqueue_aggregation_task( + &mut self, + _request: &AggregationOnlyRequest, + ) -> TaskManagerResult<()> { + todo!() + } + + async fn get_aggregation_task_proving_status( + &mut self, + _request: &AggregationOnlyRequest, + ) -> TaskManagerResult { + todo!() + } + + async fn update_aggregation_task_progress( + &mut self, + _request: &AggregationOnlyRequest, + _status: TaskStatus, + _proof: Option<&[u8]>, + ) -> TaskManagerResult<()> { + todo!() + } + + async fn get_aggregation_task_proof( + &mut self, + _request: &AggregationOnlyRequest, + ) -> TaskManagerResult> { + todo!() + } } #[cfg(test)] diff --git a/tasks/src/lib.rs b/tasks/src/lib.rs index 2abd2e741..cc7523e35 100644 --- a/tasks/src/lib.rs +++ b/tasks/src/lib.rs @@ -4,8 +4,7 @@ use std::{ }; use chrono::{DateTime, Utc}; -use num_enum::{FromPrimitive, IntoPrimitive}; -use raiko_core::interfaces::ProofType; +use raiko_core::interfaces::{AggregationOnlyRequest, ProofType}; use raiko_lib::{ primitives::{ChainId, B256}, prover::{IdStore, IdWrite, ProofKey, ProverResult}, @@ -61,24 +60,83 @@ impl From for TaskManagerError { 
#[allow(non_camel_case_types)] #[rustfmt::skip] -#[derive(PartialEq, Debug, Copy, Clone, IntoPrimitive, FromPrimitive, Deserialize, Serialize, ToSchema)] -#[repr(i32)] +#[derive(PartialEq, Debug, Clone, Deserialize, Serialize, ToSchema, Eq, PartialOrd, Ord)] #[serde(rename_all = "snake_case")] pub enum TaskStatus { - Success = 0, - Registered = 1000, - WorkInProgress = 2000, - ProofFailure_Generic = -1000, - ProofFailure_OutOfMemory = -1100, - NetworkFailure = -2000, - Cancelled = -3000, - Cancelled_NeverStarted = -3100, - Cancelled_Aborted = -3200, - CancellationInProgress = -3210, - InvalidOrUnsupportedBlock = -4000, - UnspecifiedFailureReason = -9999, - #[num_enum(default)] - SqlDbCorruption = -99999, + Success, + Registered, + WorkInProgress, + ProofFailure_Generic, + ProofFailure_OutOfMemory, + NetworkFailure, + Cancelled, + Cancelled_NeverStarted, + Cancelled_Aborted, + CancellationInProgress, + InvalidOrUnsupportedBlock, + NonDbFailure(String), + UnspecifiedFailureReason, + SqlDbCorruption, +} + +impl From for i32 { + fn from(status: TaskStatus) -> i32 { + match status { + TaskStatus::Success => 0, + TaskStatus::Registered => 1000, + TaskStatus::WorkInProgress => 2000, + TaskStatus::ProofFailure_Generic => -1000, + TaskStatus::ProofFailure_OutOfMemory => -1100, + TaskStatus::NetworkFailure => -2000, + TaskStatus::Cancelled => -3000, + TaskStatus::Cancelled_NeverStarted => -3100, + TaskStatus::Cancelled_Aborted => -3200, + TaskStatus::CancellationInProgress => -3210, + TaskStatus::InvalidOrUnsupportedBlock => -4000, + TaskStatus::NonDbFailure(_) => -5000, + TaskStatus::UnspecifiedFailureReason => -9999, + TaskStatus::SqlDbCorruption => -99999, + } + } +} + +impl From for TaskStatus { + fn from(value: i32) -> TaskStatus { + match value { + 0 => TaskStatus::Success, + 1000 => TaskStatus::Registered, + 2000 => TaskStatus::WorkInProgress, + -1000 => TaskStatus::ProofFailure_Generic, + -1100 => TaskStatus::ProofFailure_OutOfMemory, + -2000 => TaskStatus::NetworkFailure, + -3000 => TaskStatus::Cancelled, + -3100 => TaskStatus::Cancelled_NeverStarted, + -3200 => TaskStatus::Cancelled_Aborted, + -3210 => TaskStatus::CancellationInProgress, + -4000 => TaskStatus::InvalidOrUnsupportedBlock, + -5000 => TaskStatus::NonDbFailure("".to_string()), + -9999 => TaskStatus::UnspecifiedFailureReason, + -99999 => TaskStatus::SqlDbCorruption, + _ => TaskStatus::UnspecifiedFailureReason, + } + } +} + +impl FromIterator for TaskStatus { + fn from_iter>(iter: T) -> Self { + iter.into_iter() + .min() + .unwrap_or(TaskStatus::UnspecifiedFailureReason) + } +} + +impl<'a> FromIterator<&'a TaskStatus> for TaskStatus { + fn from_iter>(iter: T) -> Self { + iter.into_iter() + .min() + .cloned() + .unwrap_or(TaskStatus::UnspecifiedFailureReason) + } } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] @@ -166,6 +224,32 @@ pub trait TaskManager: IdStore + IdWrite { /// List all stored ids. async fn list_stored_ids(&mut self) -> TaskManagerResult>; + + /// Enqueue a new aggregation task to the tasks database. + async fn enqueue_aggregation_task( + &mut self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult<()>; + + /// Update a specific aggregation tasks progress. + async fn update_aggregation_task_progress( + &mut self, + request: &AggregationOnlyRequest, + status: TaskStatus, + proof: Option<&[u8]>, + ) -> TaskManagerResult<()>; + + /// Returns the latest triplet (status, proof - if any, last update time). 
+ async fn get_aggregation_task_proving_status( + &mut self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult; + + /// Returns the proof for the given aggregation task. + async fn get_aggregation_task_proof( + &mut self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult>; } pub fn ensure(expression: bool, message: &str) -> TaskManagerResult<()> { @@ -297,6 +381,68 @@ impl TaskManager for TaskManagerWrapper { TaskManagerInstance::Sqlite(manager) => manager.list_stored_ids().await, } } + + async fn enqueue_aggregation_task( + &mut self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult<()> { + match &mut self.manager { + TaskManagerInstance::InMemory(ref mut manager) => { + manager.enqueue_aggregation_task(request).await + } + TaskManagerInstance::Sqlite(ref mut manager) => { + manager.enqueue_aggregation_task(request).await + } + } + } + + async fn update_aggregation_task_progress( + &mut self, + request: &AggregationOnlyRequest, + status: TaskStatus, + proof: Option<&[u8]>, + ) -> TaskManagerResult<()> { + match &mut self.manager { + TaskManagerInstance::InMemory(ref mut manager) => { + manager + .update_aggregation_task_progress(request, status, proof) + .await + } + TaskManagerInstance::Sqlite(ref mut manager) => { + manager + .update_aggregation_task_progress(request, status, proof) + .await + } + } + } + + async fn get_aggregation_task_proving_status( + &mut self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult { + match &mut self.manager { + TaskManagerInstance::InMemory(ref mut manager) => { + manager.get_aggregation_task_proving_status(request).await + } + TaskManagerInstance::Sqlite(ref mut manager) => { + manager.get_aggregation_task_proving_status(request).await + } + } + } + + async fn get_aggregation_task_proof( + &mut self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult> { + match &mut self.manager { + TaskManagerInstance::InMemory(ref mut manager) => { + manager.get_aggregation_task_proof(request).await + } + TaskManagerInstance::Sqlite(ref mut manager) => { + manager.get_aggregation_task_proof(request).await + } + } + } } pub fn get_task_manager(opts: &TaskManagerOpts) -> TaskManagerWrapper { diff --git a/tasks/src/mem_db.rs b/tasks/src/mem_db.rs index ad6550004..f3bee7883 100644 --- a/tasks/src/mem_db.rs +++ b/tasks/src/mem_db.rs @@ -13,6 +13,7 @@ use std::{ }; use chrono::Utc; +use raiko_core::interfaces::AggregationOnlyRequest; use raiko_lib::prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult}; use tokio::sync::Mutex; use tracing::{debug, info}; @@ -29,14 +30,16 @@ pub struct InMemoryTaskManager { #[derive(Debug)] pub struct InMemoryTaskDb { - enqueue_task: HashMap, + tasks_queue: HashMap, + aggregation_tasks_queue: HashMap, store: HashMap, } impl InMemoryTaskDb { fn new() -> InMemoryTaskDb { InMemoryTaskDb { - enqueue_task: HashMap::new(), + tasks_queue: HashMap::new(), + aggregation_tasks_queue: HashMap::new(), store: HashMap::new(), } } @@ -44,7 +47,7 @@ impl InMemoryTaskDb { fn enqueue_task(&mut self, key: &TaskDescriptor) { let task_status = (TaskStatus::Registered, None, Utc::now()); - match self.enqueue_task.get(key) { + match self.tasks_queue.get(key) { Some(task_proving_records) => { debug!( "Task already exists: {:?}", @@ -53,7 +56,7 @@ impl InMemoryTaskDb { } // do nothing None => { info!("Enqueue new task: {key:?}"); - self.enqueue_task.insert(key.clone(), vec![task_status]); + self.tasks_queue.insert(key.clone(), vec![task_status]); } } } @@ -64,9 +67,9 @@ impl InMemoryTaskDb { status: 
TaskStatus, proof: Option<&[u8]>, ) -> TaskManagerResult<()> { - ensure(self.enqueue_task.contains_key(&key), "no task found")?; + ensure(self.tasks_queue.contains_key(&key), "no task found")?; - self.enqueue_task.entry(key).and_modify(|entry| { + self.tasks_queue.entry(key).and_modify(|entry| { if let Some(latest) = entry.last() { if latest.0 != status { entry.push((status, proof.map(hex::encode), Utc::now())); @@ -81,14 +84,14 @@ impl InMemoryTaskDb { &mut self, key: &TaskDescriptor, ) -> TaskManagerResult { - Ok(self.enqueue_task.get(key).cloned().unwrap_or_default()) + Ok(self.tasks_queue.get(key).cloned().unwrap_or_default()) } fn get_task_proof(&mut self, key: &TaskDescriptor) -> TaskManagerResult> { - ensure(self.enqueue_task.contains_key(key), "no task found")?; + ensure(self.tasks_queue.contains_key(key), "no task found")?; let proving_status_records = self - .enqueue_task + .tasks_queue .get(key) .ok_or_else(|| TaskManagerError::SqlError("no task in db".to_owned()))?; @@ -107,20 +110,22 @@ impl InMemoryTaskDb { } fn size(&mut self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> { - Ok((self.enqueue_task.len(), vec![])) + Ok((self.tasks_queue.len(), vec![])) } fn prune(&mut self) -> TaskManagerResult<()> { - self.enqueue_task.clear(); + self.tasks_queue.clear(); Ok(()) } fn list_all_tasks(&mut self) -> TaskManagerResult> { Ok(self - .enqueue_task + .tasks_queue .iter() .flat_map(|(descriptor, statuses)| { - statuses.last().map(|status| (descriptor.clone(), status.0)) + statuses + .last() + .map(|status| (descriptor.clone(), status.0.clone())) }) .collect()) } @@ -145,6 +150,91 @@ impl InMemoryTaskDb { .cloned() .ok_or(TaskManagerError::NoData) } + + fn enqueue_aggregation_task( + &mut self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult<()> { + let task_status = (TaskStatus::Registered, None, Utc::now()); + + match self.aggregation_tasks_queue.get(request) { + Some(task_proving_records) => { + debug!( + "Task already exists: {:?}", + task_proving_records.last().unwrap().0 + ); + } // do nothing + None => { + info!("Enqueue new task: {request}"); + self.aggregation_tasks_queue + .insert(request.clone(), vec![task_status]); + } + } + Ok(()) + } + + fn get_aggregation_task_proving_status( + &mut self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult { + Ok(self + .aggregation_tasks_queue + .get(request) + .cloned() + .unwrap_or_default()) + } + + fn update_aggregation_task_progress( + &mut self, + request: &AggregationOnlyRequest, + status: TaskStatus, + proof: Option<&[u8]>, + ) -> TaskManagerResult<()> { + ensure( + self.aggregation_tasks_queue.contains_key(request), + "no task found", + )?; + + self.aggregation_tasks_queue + .entry(request.clone()) + .and_modify(|entry| { + if let Some(latest) = entry.last() { + if latest.0 != status { + entry.push((status, proof.map(hex::encode), Utc::now())); + } + } + }); + + Ok(()) + } + + fn get_aggregation_task_proof( + &mut self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult> { + ensure( + self.aggregation_tasks_queue.contains_key(request), + "no task found", + )?; + + let proving_status_records = self + .aggregation_tasks_queue + .get(request) + .ok_or_else(|| TaskManagerError::SqlError("no task in db".to_owned()))?; + + let (_, proof, ..) 
= proving_status_records + .iter() + .filter(|(status, ..)| (status == &TaskStatus::Success)) + .last() + .ok_or_else(|| TaskManagerError::SqlError("no successful task in db".to_owned()))?; + + let Some(proof) = proof else { + return Ok(vec![]); + }; + + hex::decode(proof) + .map_err(|_| TaskManagerError::SqlError("couldn't decode from hex".to_owned())) + } } #[async_trait::async_trait] @@ -248,6 +338,40 @@ impl TaskManager for InMemoryTaskManager { let mut db = self.db.lock().await; db.list_stored_ids() } + + async fn enqueue_aggregation_task( + &mut self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult<()> { + let mut db = self.db.lock().await; + db.enqueue_aggregation_task(request) + } + + async fn get_aggregation_task_proving_status( + &mut self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult { + let mut db = self.db.lock().await; + db.get_aggregation_task_proving_status(request) + } + + async fn update_aggregation_task_progress( + &mut self, + request: &AggregationOnlyRequest, + status: TaskStatus, + proof: Option<&[u8]>, + ) -> TaskManagerResult<()> { + let mut db = self.db.lock().await; + db.update_aggregation_task_progress(request, status, proof) + } + + async fn get_aggregation_task_proof( + &mut self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult> { + let mut db = self.db.lock().await; + db.get_aggregation_task_proof(request) + } } #[cfg(test)]
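// A small round-trip sketch for the reworked TaskStatus mapping above, assuming
// it sits alongside the From impls; the NonDbFailure payload is lossy by design
// because the database column only stores the i32 code.
#[test]
fn task_status_code_round_trip() {
    assert_eq!(i32::from(TaskStatus::Success), 0);
    assert_eq!(TaskStatus::from(-1100), TaskStatus::ProofFailure_OutOfMemory);
    // The String payload cannot be recovered from the stored code.
    let lossy = TaskStatus::NonDbFailure("io error".to_string());
    assert_eq!(
        TaskStatus::from(i32::from(lossy)),
        TaskStatus::NonDbFailure("".to_string())
    );
    // Unknown codes collapse to UnspecifiedFailureReason.
    assert_eq!(TaskStatus::from(12345), TaskStatus::UnspecifiedFailureReason);
    // FromIterator takes the minimum by declaration order, so Success wins in a
    // mixed collection.
    let overall: TaskStatus = [TaskStatus::WorkInProgress, TaskStatus::Success]
        .iter()
        .collect();
    assert_eq!(overall, TaskStatus::Success);
}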