From 2c1c634eee030d8577e726aa832427fc772fea38 Mon Sep 17 00:00:00 2001 From: Kevin Mai-Husan Chia Date: Fri, 6 Oct 2023 08:03:06 +0800 Subject: [PATCH] Use the latest tlsn-prover (#11) * Use the latest tlsn@8b163540 with patch * Update rust-toolchain forcing nightly * lint --------- Co-authored-by: Ryan MacArthur --- rust-toolchain | 3 +- wasm/prover/Cargo.toml | 11 ++--- wasm/prover/src/lib.rs | 100 +++++++++++++++++++++-------------------- worker.js | 41 +++++++++++++++++ 4 files changed, 97 insertions(+), 58 deletions(-) create mode 100644 worker.js diff --git a/rust-toolchain b/rust-toolchain index 16c0bd25..43ad591a 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1,4 +1,5 @@ [toolchain] -channel = "nightly-2022-12-12" +# channel = "nightly-2022-12-12" # channel = "stable" # channel = "nightly-x86_64-apple-darwin" +channel = "nightly" diff --git a/wasm/prover/Cargo.toml b/wasm/prover/Cargo.toml index fc34c2f6..29278f5c 100644 --- a/wasm/prover/Cargo.toml +++ b/wasm/prover/Cargo.toml @@ -37,7 +37,7 @@ web-time = "0.2" # tlsn-prover = { path = "../tlsn/tlsn/tlsn-prover", features = ["tracing"] } [dependencies.tlsn-prover] git = "https://github.com/mhchia/tlsn.git" -branch = "tlsn-examples-wasm" +branch = "dev-20230921-webtime" package = "tlsn-prover" features = ["tracing"] @@ -74,16 +74,11 @@ branch="0.16.20-cleanup" git="https://github.com/mhchia/ws_stream_wasm" branch="dev" -[patch.'https://github.com/tlsnotary/tlsn-utils'] +# [patch.'https://github.com/tlsnotary/tlsn-utils'] # # Use single cpu backend # tlsn-utils = { git = 'https://www.github.com/mhchia/tlsn-utils.git', rev = "46327f0" } # tlsn-utils-aio = { git = 'https://www.github.com/mhchia/tlsn-utils.git', rev = "46327f0" } -# Use older version of multi-threaded backend -tlsn-utils = { git = 'https://www.github.com/tlsnotary/tlsn-utils.git', rev = "f3e3f07" } -tlsn-utils-aio = { git = 'https://www.github.com/tlsnotary/tlsn-utils.git', rev = "f3e3f07" } - - # The `console_error_panic_hook` crate provides better debugging of panics by # logging them with `console.error`. This is great for development, but requires # all the `std::fmt` and `std::panicking` infrastructure, so isn't great for @@ -97,4 +92,4 @@ wasm-bindgen-test = "0.3.34" # Tell `rustc` to optimize for small code size. [package.metadata.wasm-pack.profile.release] -wasm-opt = false \ No newline at end of file +wasm-opt = false diff --git a/wasm/prover/src/lib.rs b/wasm/prover/src/lib.rs index 36bf1336..955ab73e 100644 --- a/wasm/prover/src/lib.rs +++ b/wasm/prover/src/lib.rs @@ -5,7 +5,7 @@ use web_time::Instant; use hyper::{body::to_bytes, Body, Request, StatusCode}; use futures::{AsyncWriteExt, TryFutureExt}; use futures::channel::oneshot; -use tlsn_prover::{bind_prover, ProverConfig}; +use tlsn_prover::{Prover, ProverConfig}; // use tokio::io::AsyncWriteExt as _; use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; @@ -84,12 +84,14 @@ pub async fn prover() -> Result<(), JsValue> { // let message = b"Hello from browser".to_vec(); // notary_ws_stream_into.write(&message).await // .expect_throw( "Failed to write to websocket" ); + // log!("!@# 0.1"); // let mut output = [0u8; 20]; // let bytes = notary_ws_stream_into.read(&mut output[..]).await.unwrap(); // assert_eq!(bytes, 18); // log!("Received: {:?}", &output[..bytes]); + // Basic default prover config let config = ProverConfig::builder() .id("example") @@ -99,32 +101,20 @@ pub async fn prover() -> Result<(), JsValue> { log!("!@# 1"); - - log!("!@# 2"); - let (tls_connection, prover_fut, mux_fut) = - bind_prover(config, client_ws_stream_into, notary_ws_stream_into) + // Create a Prover and set it up with the Notary + // This will set up the MPC backend prior to connecting to the server. + let prover = Prover::new(config) + .setup(notary_ws_stream_into) .await .unwrap(); - log!("!@# 3"); - // Spawn the Prover and Mux tasks to be run concurrently - // tokio::spawn(mux_fut); - let handled_mux_fut = async { - log!("!@# 4"); - match mux_fut.await { - Ok(_) => { - log!("!@# 4.1"); - () - }, - Err(err) => { - panic!("An error occurred in mux_fut: {:?}", err); - } - } - }; - log!("!@# 5"); - spawn_local(handled_mux_fut); - log!("!@# 6"); + // Bind the Prover to the server connection. + // The returned `mpc_tls_connection` is an MPC TLS connection to the Server: all data written + // to/read from it will be encrypted/decrypted using MPC with the Notary. + let (mpc_tls_connection, prover_fut) = prover.connect(client_ws_stream_into).await.unwrap(); + + log!("!@# 3"); // let prover_task = tokio::spawn(prover_fut); @@ -140,12 +130,11 @@ pub async fn prover() -> Result<(), JsValue> { } } }; - // let prover_task = tokio::spawn(prover_fut); spawn_local(handled_prover_fut); log!("!@# 7"); // Attach the hyper HTTP client to the TLS connection - let (mut request_sender, connection) = hyper::client::conn::handshake(tls_connection.compat_write()) + let (mut request_sender, connection) = hyper::client::conn::handshake(mpc_tls_connection.compat()) .await .unwrap(); log!("!@# 8"); @@ -192,11 +181,12 @@ pub async fn prover() -> Result<(), JsValue> { .unwrap(); - log!("Sending request"); + log!("Starting an MPC TLS connection with the server"); + // Send the request to the Server and get a response via the MPC TLS connection let response = request_sender.send_request(request).await.unwrap(); - log!("Sent request"); + log!("Got a response from the server"); assert!(response.status() == StatusCode::OK); @@ -220,6 +210,7 @@ pub async fn prover() -> Result<(), JsValue> { // The Prover task should be done now, so we can grab it. // let mut prover = prover_task.await.unwrap().unwrap(); let mut prover = prover_receiver.await.unwrap(); + let mut prover = prover.start_notarize(); log!("!@# 14"); // Identify the ranges in the transcript that contain secrets @@ -233,29 +224,40 @@ pub async fn prover() -> Result<(), JsValue> { ); log!("!@# 15"); - // Commit to the outbound transcript, isolating the data that contain secrets - for range in public_ranges.iter().chain(private_ranges.iter()) { - prover.add_commitment_sent(range.clone()).unwrap(); - } - log!("!@# 16"); - - // Commit to the full received transcript in one shot, as we don't need to redact anything let recv_len = prover.recv_transcript().data().len(); - log!("!@# 17"); - prover.add_commitment_recv(0..recv_len as u32).unwrap(); - log!("!@# 18"); + + let builder = prover.commitment_builder(); + + // Commit to each range of the public outbound data which we want to disclose + let sent_commitments: Vec<_> = public_ranges + .iter() + .map(|r| builder.commit_sent(r.clone()).unwrap()) + .collect(); + + // Commit to all inbound data in one shot, as we don't need to redact anything in it + let recv_commitment = builder.commit_recv(0..recv_len).unwrap(); // Finalize, returning the notarized session let notarized_session = prover.finalize().await.unwrap(); - log!("!@# 19"); - log!("Notarization complete!"); - let res_str = serde_json::to_string_pretty(¬arized_session) - .unwrap(); - log!("Notarized session: {}", res_str); + // Create a proof for all committed data in this session + let session_proof = notarized_session.session_proof(); + + let mut proof_builder = notarized_session.data().build_substrings_proof(); + + // Reveal all the public ranges + for commitment_id in sent_commitments { + proof_builder.reveal(commitment_id).unwrap(); + } + proof_builder.reveal(recv_commitment).unwrap(); + + let substrings_proof = proof_builder.build().unwrap(); + let res = serde_json::to_string_pretty(&(&session_proof, &substrings_proof, &SERVER_DOMAIN)) + .unwrap(); + log!("res = {}", res); let duration = start_time.elapsed(); - log!("!@# request costs: {} seconds", duration.as_secs()); + log!("!@# request takes: {} seconds", duration.as_secs()); Ok(()) @@ -265,12 +267,12 @@ pub async fn prover() -> Result<(), JsValue> { /// Find the ranges of the public and private parts of a sequence. /// /// Returns a tuple of `(public, private)` ranges. -fn find_ranges(seq: &[u8], sub_seq: &[&[u8]]) -> (Vec>, Vec>) { +fn find_ranges(seq: &[u8], private_seq: &[&[u8]]) -> (Vec>, Vec>) { let mut private_ranges = Vec::new(); - for s in sub_seq { + for s in private_seq { for (idx, w) in seq.windows(s.len()).enumerate() { if w == *s { - private_ranges.push(idx as u32..(idx + w.len()) as u32); + private_ranges.push(idx..(idx + w.len())); } } } @@ -287,9 +289,9 @@ fn find_ranges(seq: &[u8], sub_seq: &[&[u8]]) -> (Vec>, Vec