
feat(raiko): merge stress test upgrades #392

Merged · 12 commits · Oct 21, 2024
640 changes: 306 additions & 334 deletions Cargo.lock

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions Cargo.toml
@@ -53,13 +53,12 @@ reth-chainspec = { git = "https://github.com/taikoxyz/taiko-reth.git", branch =
reth-provider = { git = "https://github.com/taikoxyz/taiko-reth.git", branch = "v1.0.0-rc.2-taiko", default-features = false }

# risc zero
risc0-zkvm = { version = "1.0.1", features = ["prove", "getrandom"] }
bonsai-sdk = { version = "0.8.0", features = ["async"] }
risc0-build = { version = "1.0.1" }
risc0-binfmt = { version = "1.0.1" }
risc0-zkvm = { version = "=1.1.2", features = ["prove", "getrandom"] }
bonsai-sdk = { version = "=1.1.2" }
risc0-binfmt = { version = "=1.1.2" }

# SP1
sp1-sdk = { version = "2.0.0" }
sp1-sdk = { version = "=3.0.0-rc3" }
sp1-zkvm = { version = "2.0.0" }
sp1-helper = { version = "2.0.0" }
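The new dependency block pins exact versions with = requirements and drops bonsai-sdk's "async" feature, matching the driver's move to the crate's blocking client further down. A minimal sketch of constructing that client (assumes bonsai-sdk 1.x with BONSAI_API_URL and BONSAI_API_KEY set in the environment; not part of this diff):

    use bonsai_sdk::blocking::Client;

    // Reads BONSAI_API_URL and BONSAI_API_KEY from the environment; the version
    // string ties requests to the expected risc0 zkVM release.
    fn make_client() -> Result<Client, bonsai_sdk::SdkErr> {
        Client::from_env(risc0_zkvm::VERSION)
    }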

2 changes: 1 addition & 1 deletion Dockerfile.zk
@@ -40,7 +40,7 @@ RUN echo "Building for sp1"
ENV TARGET=sp1
RUN make install
RUN make guest
RUN cargo build --release ${BUILD_FLAGS} --features "sp1,risc0,bonsai-auto-scaling" --features "docker_build"
RUN cargo build --release ${BUILD_FLAGS} --features "sp1,risc0" --features "docker_build"

RUN mkdir -p \
./bin \
7 changes: 5 additions & 2 deletions docker/docker-compose.yml
@@ -130,10 +130,11 @@ services:
volumes:
- /var/log/raiko:/var/log/raiko
ports:
- "8081:8080"
- "8080:8080"
environment:
# you can use your own PCCS host
# - PCCS_HOST=host.docker.internal:8081
- RUST_LOG=${RUST_LOG:-info}
- ZK=true
- ETHEREUM_RPC=${ETHEREUM_RPC}
- ETHEREUM_BEACON_RPC=${ETHEREUM_BEACON_RPC}
@@ -145,11 +146,13 @@
- NETWORK=${NETWORK}
- BONSAI_API_KEY=${BONSAI_API_KEY}
- BONSAI_API_URL=${BONSAI_API_URL}
- MAX_BONSAI_GPU_NUM=15
- MAX_BONSAI_GPU_NUM=300
- GROTH16_VERIFIER_RPC_URL=${GROTH16_VERIFIER_RPC_URL}
- GROTH16_VERIFIER_ADDRESS=${GROTH16_VERIFIER_ADDRESS}
- SP1_PRIVATE_KEY=${SP1_PRIVATE_KEY}
- SKIP_SIMULATION=true
- SP1_VERIFIER_RPC_URL=${SP1_VERIFIER_RPC_URL}
- SP1_VERIFIER_ADDRESS=${SP1_VERIFIER_ADDRESS}
pccs:
build:
context: ..
11 changes: 7 additions & 4 deletions host/src/proof.rs
@@ -105,15 +105,18 @@ impl ProofActor {
pub async fn run_task(&mut self, proof_request: ProofRequest) {
let cancel_token = CancellationToken::new();

let Ok((chain_id, blockhash)) = get_task_data(
let (chain_id, blockhash) = match get_task_data(
&proof_request.network,
proof_request.block_number,
&self.chain_specs,
)
.await
else {
error!("Could not get task data for {proof_request:?}");
return;
{
Ok(v) => v,
Err(e) => {
error!("Could not get task data for {proof_request:?}, error: {e}");
return;
}
};

let key = TaskDescriptor::from((
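The change above swaps a let-else binding for a match so the error from get_task_data is logged instead of silently discarded. A standalone sketch of the pattern with a hypothetical fetch_task_data stand-in (not part of this PR):

    // Hypothetical stand-in for get_task_data.
    fn fetch_task_data(block: u64) -> Result<(u64, String), String> {
        Ok((167_000, format!("hash-for-{block}")))
    }

    fn run(block: u64) {
        // A `let Ok(..) = .. else { return; }` form would drop the error value;
        // matching keeps it so it can be logged before bailing out.
        let (chain_id, blockhash) = match fetch_task_data(block) {
            Ok(v) => v,
            Err(e) => {
                eprintln!("could not get task data for block {block}: {e}");
                return;
            }
        };
        println!("chain {chain_id}, block hash {blockhash}");
    }

    fn main() {
        run(42);
    }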
60 changes: 39 additions & 21 deletions provers/risc0/driver/src/bonsai.rs
@@ -4,6 +4,7 @@ use crate::{
Risc0Response,
};
use alloy_primitives::B256;
use bonsai_sdk::blocking::{Client, SessionId};
use log::{debug, error, info, warn};
use raiko_lib::{
primitives::keccak::keccak,
@@ -19,14 +20,17 @@
fs,
path::{Path, PathBuf},
};
use tokio::time::{sleep as tokio_async_sleep, Duration};

use crate::Risc0Param;

const MAX_REQUEST_RETRY: usize = 8;

#[derive(thiserror::Error, Debug)]
pub enum BonsaiExecutionError {
// common errors: include sdk error, or some others from non-bonsai code
#[error(transparent)]
SdkFailure(#[from] bonsai_sdk::alpha::SdkErr),
SdkFailure(#[from] bonsai_sdk::SdkErr),
#[error("bonsai execution error: {0}")]
Other(String),
// critical error like OOM, which is un-recoverable
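This hunk introduces a shared MAX_REQUEST_RETRY constant and rewraps SdkFailure around bonsai_sdk::SdkErr, since the 1.x crate drops the alpha module. A minimal sketch of the transparent/from error pattern the enum relies on, using a stand-in error type rather than the real SDK error:

    use thiserror::Error;

    #[derive(Error, Debug)]
    #[error("stand-in SDK error: {0}")]
    struct FakeSdkErr(String);

    #[derive(Error, Debug)]
    enum ExecutionError {
        // transparent forwards Display and source() to the wrapped error;
        // #[from] lets `?` convert FakeSdkErr into this variant automatically.
        #[error(transparent)]
        Sdk(#[from] FakeSdkErr),
        // Unrecoverable conditions get their own variant so callers stop retrying.
        #[error("fatal: {0}")]
        Fatal(String),
    }

    fn call_sdk() -> Result<(), FakeSdkErr> {
        Err(FakeSdkErr("connection reset".into()))
    }

    fn run() -> Result<(), ExecutionError> {
        call_sdk()?; // converted via #[from]
        Ok(())
    }

    fn main() {
        if let Err(e) = run() {
            eprintln!("execution failed: {e}");
        }
    }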
@@ -44,12 +48,12 @@
max_retries: usize,
) -> Result<(String, Receipt), BonsaiExecutionError> {
info!("Tracking receipt uuid: {uuid}");
let session = bonsai_sdk::alpha::SessionId { uuid };
let session = SessionId { uuid };

loop {
let mut res = None;
for attempt in 1..=max_retries {
let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let client = Client::from_env(risc0_zkvm::VERSION)?;

match session.status(&client) {
Ok(response) => {
@@ -61,7 +65,7 @@
return Err(BonsaiExecutionError::SdkFailure(err));
}
warn!("Attempt {attempt}/{max_retries} for session status request: {err:?}");
std::thread::sleep(std::time::Duration::from_secs(15));
tokio_async_sleep(Duration::from_secs(15)).await;
continue;
}
}
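This hunk (and several below) replaces std::thread::sleep with tokio::time::sleep inside async code: the former parks the whole executor thread, the latter only suspends the current task so other proof jobs keep making progress. A minimal sketch of the polling pattern, with a dummy request in place of the Bonsai status call (assumes a Tokio runtime):

    use tokio::time::{sleep, Duration};

    // Dummy stand-in for session.status(&client).
    fn try_request(attempt: usize) -> Result<&'static str, &'static str> {
        if attempt < 3 {
            Err("transient network error")
        } else {
            Ok("SUCCEEDED")
        }
    }

    async fn poll_with_retries(max_retries: usize) -> Option<&'static str> {
        for attempt in 1..=max_retries {
            match try_request(attempt) {
                Ok(status) => return Some(status),
                Err(err) => {
                    eprintln!("attempt {attempt}/{max_retries} failed: {err}");
                    // Non-blocking wait: yields to the runtime instead of
                    // parking the thread like std::thread::sleep would.
                    sleep(Duration::from_millis(200)).await;
                }
            }
        }
        None
    }

    #[tokio::main]
    async fn main() {
        assert_eq!(poll_with_retries(8).await, Some("SUCCEEDED"));
    }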
@@ -72,17 +76,18 @@

if res.status == "RUNNING" {
info!(
"Current status: {} - state: {} - continue polling...",
"Current {session:?} status: {} - state: {} - continue polling...",
res.status,
res.state.unwrap_or_default()
);
std::thread::sleep(std::time::Duration::from_secs(15));
tokio_async_sleep(Duration::from_secs(15)).await;
} else if res.status == "SUCCEEDED" {
// Download the receipt, containing the output
info!("Prove task {session:?} success.");
let receipt_url = res
.receipt_url
.expect("API error, missing receipt on completed session");
let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let client = Client::from_env(risc0_zkvm::VERSION)?;
let receipt_buf = client.download(&receipt_url)?;
let receipt: Receipt = bincode::deserialize(&receipt_buf).map_err(|e| {
BonsaiExecutionError::Other(format!("Failed to deserialize receipt: {e:?}"))
@@ -104,10 +109,10 @@
}
return Ok((session.uuid, receipt));
} else {
let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let client = Client::from_env(risc0_zkvm::VERSION)?;
let bonsai_err_log = session.logs(&client);
return Err(BonsaiExecutionError::Fatal(format!(
"Workflow exited: {} - | err: {} | log: {bonsai_err_log:?}",
"Workflow {session:?} exited: {} - | err: {} | log: {bonsai_err_log:?}",
res.status,
res.error_msg.unwrap_or_default(),
)));
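Once the proving session succeeds, the downloaded bytes are deserialized back into a Receipt, checked against the expected output, and verified. A compact sketch of that flow outside the polling loop (assumes the risc0_zkvm and bincode crates; the u64 journal type is illustrative):

    use risc0_zkvm::Receipt;

    // Deserialize downloaded receipt bytes, verify them against the guest image,
    // and decode the committed journal (assumed here to be a u64).
    fn check_receipt(receipt_buf: &[u8], image_id: [u32; 8]) -> u64 {
        let receipt: Receipt =
            bincode::deserialize(receipt_buf).expect("failed to deserialize receipt");
        receipt
            .verify(image_id)
            .expect("receipt does not verify against the expected image id");
        receipt
            .journal
            .decode()
            .expect("failed to decode journal output")
    }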
@@ -167,11 +172,11 @@
}
Err(BonsaiExecutionError::SdkFailure(err)) => {
warn!("Bonsai SDK fail: {err:?}, keep tracking...");
std::thread::sleep(std::time::Duration::from_secs(15));
tokio_async_sleep(Duration::from_secs(15)).await;
}
Err(BonsaiExecutionError::Other(err)) => {
warn!("Something wrong: {err:?}, keep tracking...");
std::thread::sleep(std::time::Duration::from_secs(15));
tokio_async_sleep(Duration::from_secs(15)).await;
}
Err(BonsaiExecutionError::Fatal(err)) => {
error!("Fatal error on Bonsai: {err:?}");
@@ -228,13 +233,13 @@
}

pub async fn upload_receipt(receipt: &Receipt) -> anyhow::Result<String> {
let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let client = Client::from_env(risc0_zkvm::VERSION)?;
Ok(client.upload_receipt(bincode::serialize(receipt)?)?)
}

pub async fn cancel_proof(uuid: String) -> anyhow::Result<()> {
let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let session = bonsai_sdk::alpha::SessionId { uuid };
let client = Client::from_env(risc0_zkvm::VERSION)?;
let session = SessionId { uuid };
session.stop(&client)?;
#[cfg(feature = "bonsai-auto-scaling")]
auto_scaling::shutdown_bonsai().await?;
@@ -257,7 +262,7 @@
// Prepare input data
let input_data = bytemuck::cast_slice(&encoded_input).to_vec();

let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let client = Client::from_env(risc0_zkvm::VERSION)?;
client.upload_img(&encoded_image_id, elf.to_vec())?;
// upload input
let input_id = client.upload_input(input_data.clone())?;
@@ -266,6 +271,7 @@
encoded_image_id.clone(),
input_id.clone(),
assumption_uuids.clone(),
false,
)?;

if let Some(id_store) = id_store {
@@ -277,7 +283,13 @@
})?;
}

verify_bonsai_receipt(image_id, expected_output, session.uuid.clone(), 8).await
verify_bonsai_receipt(
image_id,
expected_output,
session.uuid.clone(),
MAX_REQUEST_RETRY,
)
.await
}

pub async fn bonsai_stark_to_snark(
@@ -286,10 +298,14 @@
input: B256,
) -> ProverResult<Risc0Response> {
let image_id = Digest::from(RISC0_GUEST_ID);
let (snark_uuid, snark_receipt) =
stark2snark(image_id, stark_uuid.clone(), stark_receipt.clone())
.await
.map_err(|err| format!("Failed to convert STARK to SNARK: {err:?}"))?;
let (snark_uuid, snark_receipt) = stark2snark(
image_id,
stark_uuid.clone(),
stark_receipt.clone(),
MAX_REQUEST_RETRY,
)
.await
.map_err(|err| format!("Failed to convert STARK to SNARK: {err:?}"))?;

info!("Validating SNARK uuid: {snark_uuid}");

@@ -382,8 +398,10 @@

pub fn save_receipt<T: serde::Serialize>(receipt_label: &String, receipt_data: &(String, T)) {
if !is_dev_mode() {
let cache_path = zkp_cache_path(receipt_label);
info!("Saving receipt to cache: {cache_path:?}");
fs::write(
zkp_cache_path(receipt_label),
cache_path,
bincode::serialize(receipt_data).expect("Failed to serialize receipt!"),
)
.expect("Failed to save receipt output file.");
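The last hunk only logs the cache path before writing; the cache itself is a bincode blob keyed by a receipt label. A minimal sketch of that save/load pattern (path and helper names are illustrative, not the driver's):

    use std::{fs, path::PathBuf};

    // Illustrative cache location; the real zkp_cache_path helper differs.
    fn cache_path(label: &str) -> PathBuf {
        PathBuf::from("/tmp/zkp-cache").join(format!("{label}.zkp"))
    }

    fn save_cached<T: serde::Serialize>(label: &str, data: &T) {
        let path = cache_path(label);
        println!("saving receipt to cache: {}", path.display());
        fs::create_dir_all(path.parent().unwrap()).expect("failed to create cache dir");
        let bytes = bincode::serialize(data).expect("failed to serialize receipt");
        fs::write(&path, bytes).expect("failed to write cache file");
    }

    fn load_cached<T: serde::de::DeserializeOwned>(label: &str) -> Option<T> {
        let bytes = fs::read(cache_path(label)).ok()?;
        bincode::deserialize(&bytes).ok()
    }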
2 changes: 1 addition & 1 deletion provers/risc0/driver/src/methods/risc0_aggregation.rs
@@ -1,5 +1,5 @@
pub const RISC0_AGGREGATION_ELF: &[u8] =
include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/risc0-aggregation");
pub const RISC0_AGGREGATION_ID: [u32; 8] = [
3593026424, 359928015, 3488866833, 2676323972, 1129344711, 55769507, 233041442, 3293280986,
3190692238, 1991537256, 2457220677, 1764592515, 1585399420, 97928005, 276688816, 447831862,
];
2 changes: 1 addition & 1 deletion provers/risc0/driver/src/methods/risc0_guest.rs
@@ -1,5 +1,5 @@
pub const RISC0_GUEST_ELF: &[u8] =
include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/risc0-guest");
pub const RISC0_GUEST_ID: [u32; 8] = [
2522428380, 1790994278, 397707036, 244564411, 3780865207, 1282154214, 1673205005, 3172292887,
3473581204, 2561439051, 2320161003, 3018340632, 1481329104, 1608433297, 3314099706, 2669934765,
];
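Both image-ID constants change because the guests were rebuilt against the newer risc0 toolchain, so receipts produced by the old guest images will no longer verify. A short sketch of how a [u32; 8] image ID is used at verification time (GUEST_ID is a placeholder for the generated constant):

    use risc0_zkvm::{sha::Digest, Receipt};

    // Placeholder for the generated RISC0_GUEST_ID constant.
    const GUEST_ID: [u32; 8] = [0; 8];

    fn verify_against_guest(receipt: &Receipt) -> bool {
        // The [u32; 8] constant converts directly into a Digest, and verification
        // fails unless the receipt was produced by exactly that guest image.
        let image_id = Digest::from(GUEST_ID);
        receipt.verify(image_id).is_ok()
    }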
47 changes: 34 additions & 13 deletions provers/risc0/driver/src/snarks.rs
@@ -17,7 +17,7 @@ use std::{str::FromStr, sync::Arc};
use alloy_primitives::B256;
use alloy_sol_types::{sol, SolValue};
use anyhow::Result;
use bonsai_sdk::alpha::responses::SnarkReceipt;
use bonsai_sdk::blocking::Client;
use ethers_contract::abigen;
use ethers_core::types::H160;
use ethers_providers::{Http, Provider, RetryClient};
@@ -27,6 +27,7 @@ use risc0_zkvm::{
sha::{Digest, Digestible},
Groth16ReceiptVerifierParameters, Receipt,
};
use tokio::time::{sleep as tokio_async_sleep, Duration};

use tracing::{error as tracing_err, info as tracing_info};

@@ -86,7 +87,8 @@ pub async fn stark2snark(
image_id: Digest,
stark_uuid: String,
stark_receipt: Receipt,
) -> Result<(String, SnarkReceipt)> {
max_retries: usize,
) -> Result<(String, Receipt)> {
info!("Submitting SNARK workload");
// Label snark output as journal digest
let receipt_label = format!(
@@ -106,20 +108,38 @@
stark_uuid
};

let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let snark_uuid = client.create_snark(stark_uuid)?;
let client = Client::from_env(risc0_zkvm::VERSION)?;
let snark_uuid = client.create_snark(stark_uuid.clone())?;

let mut retry = 0;
let snark_receipt = loop {
let res = snark_uuid.status(&client)?;

if res.status == "RUNNING" {
info!("Current status: {} - continue polling...", res.status);
std::thread::sleep(std::time::Duration::from_secs(15));
info!(
"Current {:?} status: {} - continue polling...",
&stark_uuid, res.status
);
tokio_async_sleep(Duration::from_secs(15)).await;
} else if res.status == "SUCCEEDED" {
break res
let download_url = res
.output
.expect("Bonsai response is missing SnarkReceipt.");
let receipt_buf = client.download(&download_url)?;
let snark_receipt: Receipt = bincode::deserialize(&receipt_buf)?;
break snark_receipt;
} else {
if retry < max_retries {
retry += 1;
info!(
"Workflow {:?} exited: {} - | err: {} - retrying {}/{max_retries}",
stark_uuid,
res.status,
res.error_msg.unwrap_or_default(),
retry
);
tokio_async_sleep(Duration::from_secs(15)).await;
continue;
}
panic!(
"Workflow exited: {} - | err: {}",
res.status,
@@ -129,15 +149,15 @@
};

let stark_psd = stark_receipt.claim()?.as_value().unwrap().post.digest();
let snark_psd = Digest::try_from(snark_receipt.post_state_digest.as_slice())?;
let snark_psd = snark_receipt.claim()?.as_value().unwrap().post.digest();

if stark_psd != snark_psd {
error!("SNARK/STARK Post State Digest mismatch!");
error!("STARK: {}", hex::encode(stark_psd));
error!("SNARK: {}", hex::encode(snark_psd));
}

if snark_receipt.journal != stark_receipt.journal.bytes {
if snark_receipt.journal.bytes != stark_receipt.journal.bytes {
error!("SNARK/STARK Receipt Journal mismatch!");
error!("STARK: {}", hex::encode(&stark_receipt.journal.bytes));
error!("SNARK: {}", hex::encode(&snark_receipt.journal));
@@ -152,11 +172,12 @@

pub async fn verify_groth16_from_snark_receipt(
image_id: Digest,
snark_receipt: SnarkReceipt,
snark_receipt: Receipt,
) -> Result<Vec<u8>> {
let seal = encode(snark_receipt.snark.to_vec())?;
let groth16_claim = snark_receipt.inner.groth16().unwrap();
let seal = groth16_claim.seal.clone();
let journal_digest = snark_receipt.journal.digest();
let post_state_digest = snark_receipt.post_state_digest.digest();
let post_state_digest = snark_receipt.claim()?.as_value().unwrap().post.digest();
let encoded_proof =
verify_groth16_snark_impl(image_id, seal, journal_digest, post_state_digest).await?;
let proof = (encoded_proof, B256::from_slice(image_id.as_bytes()))
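With bonsai-sdk 1.x the SNARK comes back as a regular Receipt instead of a dedicated SnarkReceipt, so the seal, journal digest, and post-state digest are all pulled from that receipt, as the hunk above shows. A compact sketch of the extraction (assumes the risc0_zkvm 1.x receipt types; names are illustrative):

    use risc0_zkvm::{
        sha::{Digest, Digestible},
        Receipt,
    };

    // Pull the Groth16 seal, journal digest, and post-state digest out of a
    // SNARK receipt, mirroring verify_groth16_from_snark_receipt.
    fn extract_groth16_parts(receipt: &Receipt) -> Option<(Vec<u8>, Digest, Digest)> {
        // Present only when the receipt came from STARK-to-SNARK conversion.
        let seal = receipt.inner.groth16().ok()?.seal.clone();
        // Digest of the public journal committed by the guest.
        let journal_digest = receipt.journal.digest();
        // Post-state digest taken from the receipt claim.
        let post_state_digest = receipt.claim().ok()?.as_value().ok()?.post.digest();
        Some((seal, journal_digest, post_state_digest))
    }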