Skip to content

Commit

Permalink
feat(core,host): initial aggregation API (#375)
Browse files Browse the repository at this point in the history
* initial proof aggregation implementation

* aggregation improvements + risc0 aggregation

* sp1 aggregation fixes

* sp1 aggregation elf

* uuid support for risc0 aggregation

* risc0 aggregation circuit compile fixes

* fix sgx proof aggregation

* fmt

* feat(core,host): initial aggregation API

* fix(core,host,sgx): fix compiler and clippy errors

* fix(core,lib,provers): revert merge bugs and add sp1 stubs

* fix(core): remove double member

* fix(sp1): fix dependency naming

* refactor(risc0): clean up aggregation file

* fix(sp1): enable verification for proof aggregation

* feat(host): migrate to v3 API

* feat(sp1): run cargo fmt

* feat(core): make `l1_inclusion_block_number` optional

* fixproof req input into prove state manager

Signed-off-by: smtmfft <[email protected]>

* feat(core,host,lib,tasks): add aggregation tasks and API

* fix(core): fix typo

* fix v3 error return

Signed-off-by: smtmfft <[email protected]>

* feat(sp1): implement aggregate function

* fix sgx aggregation for back compatibility

Signed-off-by: smtmfft <[email protected]>

* fix(lib): fix typo

* fix risc0 aggregation

Signed-off-by: smtmfft <[email protected]>

* fix(host,sp1): handle statuses

* enable sp1 aggregation

Signed-off-by: smtmfft <[email protected]>

* feat(host): error out on empty proof array request

* fix(host): return proper status report

* feat(host,tasks): adding details to error statuses

* fix sp1 aggregation

Signed-off-by: smtmfft <[email protected]>

* update prove-block script

Signed-off-by: smtmfft <[email protected]>

* fix(fmt): run cargo fmt

* fix(clippy): fix clippy issues

* chore(repo): cleanup captured vars in format calls

* fix(sp1): convert to proper types

* chore(sp1): remove the unneccessary

---------

Signed-off-by: smtmfft <[email protected]>
Co-authored-by: Brecht Devos <[email protected]>
Co-authored-by: smtmfft <[email protected]>
  • Loading branch information
3 people authored Oct 10, 2024
1 parent 7e10837 commit eb4d032
Show file tree
Hide file tree
Showing 54 changed files with 2,195 additions and 191 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

169 changes: 166 additions & 3 deletions core/src/interfaces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ use alloy_primitives::{Address, B256};
use clap::{Args, ValueEnum};
use raiko_lib::{
consts::VerifierType,
input::{BlobProofType, GuestInput, GuestOutput},
input::{
AggregationGuestInput, AggregationGuestOutput, BlobProofType, GuestInput, GuestOutput,
},
prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverError},
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::{serde_as, DisplayFromStr};
use std::{collections::HashMap, path::Path, str::FromStr};
use std::{collections::HashMap, fmt::Display, path::Path, str::FromStr};
use utoipa::ToSchema;

#[derive(Debug, thiserror::Error, ToSchema)]
Expand Down Expand Up @@ -203,6 +205,47 @@ impl ProofType {
}
}

/// Run the prover driver depending on the proof type.
pub async fn aggregate_proofs(
&self,
input: AggregationGuestInput,
output: &AggregationGuestOutput,
config: &Value,
store: Option<&mut dyn IdWrite>,
) -> RaikoResult<Proof> {
let proof = match self {
ProofType::Native => NativeProver::aggregate(input.clone(), output, config, store)
.await
.map_err(<ProverError as Into<RaikoError>>::into),
ProofType::Sp1 => {
#[cfg(feature = "sp1")]
return sp1_driver::Sp1Prover::aggregate(input.clone(), output, config, store)
.await
.map_err(|e| e.into());
#[cfg(not(feature = "sp1"))]
Err(RaikoError::FeatureNotSupportedError(*self))
}
ProofType::Risc0 => {
#[cfg(feature = "risc0")]
return risc0_driver::Risc0Prover::aggregate(input.clone(), output, config, store)
.await
.map_err(|e| e.into());
#[cfg(not(feature = "risc0"))]
Err(RaikoError::FeatureNotSupportedError(*self))
}
ProofType::Sgx => {
#[cfg(feature = "sgx")]
return sgx_prover::SgxProver::aggregate(input.clone(), output, config, store)
.await
.map_err(|e| e.into());
#[cfg(not(feature = "sgx"))]
Err(RaikoError::FeatureNotSupportedError(*self))
}
}?;

Ok(proof)
}

pub async fn cancel_proof(
&self,
proof_key: ProofKey,
Expand Down Expand Up @@ -302,7 +345,7 @@ pub struct ProofRequestOpt {
pub prover_args: ProverSpecificOpts,
}

#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema, Args)]
#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema, Args, PartialEq, Eq, Hash)]
pub struct ProverSpecificOpts {
/// Native prover specific options.
pub native: Option<Value>,
Expand Down Expand Up @@ -398,3 +441,123 @@ impl TryFrom<ProofRequestOpt> for ProofRequest {
})
}
}

#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema)]
#[serde(default)]
/// A request for proof aggregation of multiple proofs.
pub struct AggregationRequest {
/// The block numbers and l1 inclusion block numbers for the blocks to aggregate proofs for.
pub block_numbers: Vec<(u64, Option<u64>)>,
/// The network to generate the proof for.
pub network: Option<String>,
/// The L1 network to generate the proof for.
pub l1_network: Option<String>,
// Graffiti.
pub graffiti: Option<String>,
/// The protocol instance data.
pub prover: Option<String>,
/// The proof type.
pub proof_type: Option<String>,
/// Blob proof type.
pub blob_proof_type: Option<String>,
#[serde(flatten)]
/// Any additional prover params in JSON format.
pub prover_args: ProverSpecificOpts,
}

impl AggregationRequest {
/// Merge proof request options into aggregation request options.
pub fn merge(&mut self, opts: &ProofRequestOpt) -> RaikoResult<()> {
let this = serde_json::to_value(&self)?;
let mut opts = serde_json::to_value(opts)?;
merge(&mut opts, &this);
*self = serde_json::from_value(opts)?;
Ok(())
}
}

impl From<AggregationRequest> for Vec<ProofRequestOpt> {
fn from(value: AggregationRequest) -> Self {
value
.block_numbers
.iter()
.map(
|&(block_number, l1_inclusion_block_number)| ProofRequestOpt {
block_number: Some(block_number),
l1_inclusion_block_number,
network: value.network.clone(),
l1_network: value.l1_network.clone(),
graffiti: value.graffiti.clone(),
prover: value.prover.clone(),
proof_type: value.proof_type.clone(),
blob_proof_type: value.blob_proof_type.clone(),
prover_args: value.prover_args.clone(),
},
)
.collect()
}
}

impl From<ProofRequestOpt> for AggregationRequest {
fn from(value: ProofRequestOpt) -> Self {
let block_numbers = if let Some(block_number) = value.block_number {
vec![(block_number, value.l1_inclusion_block_number)]
} else {
vec![]
};

Self {
block_numbers,
network: value.network,
l1_network: value.l1_network,
graffiti: value.graffiti,
prover: value.prover,
proof_type: value.proof_type,
blob_proof_type: value.blob_proof_type,
prover_args: value.prover_args,
}
}
}

#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema, PartialEq, Eq, Hash)]
#[serde(default)]
/// A request for proof aggregation of multiple proofs.
pub struct AggregationOnlyRequest {
/// The block numbers and l1 inclusion block numbers for the blocks to aggregate proofs for.
pub proofs: Vec<Proof>,
/// The proof type.
pub proof_type: Option<String>,
#[serde(flatten)]
/// Any additional prover params in JSON format.
pub prover_args: ProverSpecificOpts,
}

impl Display for AggregationOnlyRequest {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&format!(
"AggregationOnlyRequest {{ {:?}, {:?} }}",
self.proof_type, self.prover_args
))
}
}

impl From<(AggregationRequest, Vec<Proof>)> for AggregationOnlyRequest {
fn from((request, proofs): (AggregationRequest, Vec<Proof>)) -> Self {
Self {
proofs,
proof_type: request.proof_type,
prover_args: request.prover_args,
}
}
}

impl AggregationOnlyRequest {
/// Merge proof request options into aggregation request options.
pub fn merge(&mut self, opts: &ProofRequestOpt) -> RaikoResult<()> {
let this = serde_json::to_value(&self)?;
let mut opts = serde_json::to_value(opts)?;
merge(&mut opts, &this);
*self = serde_json::from_value(opts)?;
Ok(())
}
}
71 changes: 59 additions & 12 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,9 @@ mod tests {
use clap::ValueEnum;
use raiko_lib::{
consts::{Network, SupportedChainSpecs},
input::BlobProofType,
input::{AggregationGuestInput, AggregationGuestOutput, BlobProofType},
primitives::B256,
prover::Proof,
};
use serde_json::{json, Value};
use std::{collections::HashMap, env};
Expand All @@ -242,7 +243,7 @@ mod tests {
ci == "1"
}

fn test_proof_params() -> HashMap<String, Value> {
fn test_proof_params(enable_aggregation: bool) -> HashMap<String, Value> {
let mut prover_args = HashMap::new();
prover_args.insert(
"native".to_string(),
Expand All @@ -256,7 +257,7 @@ mod tests {
"sp1".to_string(),
json! {
{
"recursion": "core",
"recursion": if enable_aggregation { "compressed" } else { "plonk" },
"prover": "mock",
"verify": true
}
Expand All @@ -278,8 +279,8 @@ mod tests {
json! {
{
"instance_id": 121,
"setup": true,
"bootstrap": true,
"setup": enable_aggregation,
"bootstrap": enable_aggregation,
"prove": true,
}
},
Expand All @@ -291,7 +292,7 @@ mod tests {
l1_chain_spec: ChainSpec,
taiko_chain_spec: ChainSpec,
proof_request: ProofRequest,
) {
) -> Proof {
let provider =
RpcBlockDataProvider::new(&taiko_chain_spec.rpc, proof_request.block_number - 1)
.expect("Could not create RpcBlockDataProvider");
Expand All @@ -301,10 +302,10 @@ mod tests {
.await
.expect("input generation failed");
let output = raiko.get_output(&input).expect("output generation failed");
let _proof = raiko
raiko
.prove(input, &output, None)
.await
.expect("proof generation failed");
.expect("proof generation failed")
}

#[ignore]
Expand Down Expand Up @@ -332,7 +333,7 @@ mod tests {
l1_network,
proof_type,
blob_proof_type: BlobProofType::ProofOfEquivalence,
prover_args: test_proof_params(),
prover_args: test_proof_params(false),
};
prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;
}
Expand Down Expand Up @@ -361,7 +362,7 @@ mod tests {
l1_network,
proof_type,
blob_proof_type: BlobProofType::ProofOfEquivalence,
prover_args: test_proof_params(),
prover_args: test_proof_params(false),
};
prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;
}
Expand Down Expand Up @@ -399,7 +400,7 @@ mod tests {
l1_network,
proof_type,
blob_proof_type: BlobProofType::ProofOfEquivalence,
prover_args: test_proof_params(),
prover_args: test_proof_params(false),
};
prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;
}
Expand Down Expand Up @@ -432,9 +433,55 @@ mod tests {
l1_network,
proof_type,
blob_proof_type: BlobProofType::ProofOfEquivalence,
prover_args: test_proof_params(),
prover_args: test_proof_params(false),
};
prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;
}
}

#[tokio::test(flavor = "multi_thread")]
async fn test_prove_block_taiko_a7_aggregated() {
let proof_type = get_proof_type_from_env();
let l1_network = Network::Holesky.to_string();
let network = Network::TaikoA7.to_string();
// Give the CI an simpler block to test because it doesn't have enough memory.
// Unfortunately that also means that kzg is not getting fully verified by CI.
let block_number = if is_ci() { 105987 } else { 101368 };
let taiko_chain_spec = SupportedChainSpecs::default()
.get_chain_spec(&network)
.unwrap();
let l1_chain_spec = SupportedChainSpecs::default()
.get_chain_spec(&l1_network)
.unwrap();

let proof_request = ProofRequest {
block_number,
l1_inclusion_block_number: 0,
network,
graffiti: B256::ZERO,
prover: Address::ZERO,
l1_network,
proof_type,
blob_proof_type: BlobProofType::ProofOfEquivalence,
prover_args: test_proof_params(true),
};
let proof = prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;

let input = AggregationGuestInput {
proofs: vec![proof.clone(), proof],
};

let output = AggregationGuestOutput { hash: B256::ZERO };

let aggregated_proof = proof_type
.aggregate_proofs(
input,
&output,
&serde_json::to_value(&test_proof_params(false)).unwrap(),
None,
)
.await
.expect("proof aggregation failed");
println!("aggregated proof: {aggregated_proof:?}");
}
}
5 changes: 1 addition & 4 deletions core/src/preflight/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,8 @@ pub async fn prepare_taiko_chain_input(
RaikoError::Preflight("No L1 inclusion block hash for the requested block".to_owned())
})?;
info!(
"L1 inclusion block number: {:?}, hash: {:?}. L1 state block number: {:?}, hash: {:?}",
l1_inclusion_block_number,
l1_inclusion_block_hash,
"L1 inclusion block number: {l1_inclusion_block_number:?}, hash: {l1_inclusion_block_hash:?}. L1 state block number: {:?}, hash: {l1_state_block_hash:?}",
l1_state_header.number,
l1_state_block_hash
);

// Fetch the tx data from either calldata or blobdata
Expand Down
14 changes: 14 additions & 0 deletions core/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,28 @@ impl Prover for NativeProver {
}

Ok(Proof {
input: None,
proof: None,
quote: None,
uuid: None,
kzg_proof: None,
})
}

async fn cancel(_proof_key: ProofKey, _read: Box<&mut dyn IdStore>) -> ProverResult<()> {
Ok(())
}

async fn aggregate(
_input: raiko_lib::input::AggregationGuestInput,
_output: &raiko_lib::input::AggregationGuestOutput,
_config: &ProverConfig,
_store: Option<&mut dyn IdWrite>,
) -> ProverResult<Proof> {
Ok(Proof {
..Default::default()
})
}
}

#[ignore = "Only used to test serialized data"]
Expand Down
Loading

0 comments on commit eb4d032

Please sign in to comment.