Skip to content

Commit

Permalink
Use the latest tlsn-prover (#11)
Browse files Browse the repository at this point in the history
* Use the latest tlsn@8b163540 with patch

* Update rust-toolchain

forcing nightly

* lint

---------

Co-authored-by: Ryan MacArthur <[email protected]>
  • Loading branch information
mhchia and maceip authored Oct 6, 2023
1 parent ed65205 commit 2c1c634
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 58 deletions.
3 changes: 2 additions & 1 deletion rust-toolchain
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[toolchain]
channel = "nightly-2022-12-12"
# channel = "nightly-2022-12-12"
# channel = "stable"
# channel = "nightly-x86_64-apple-darwin"
channel = "nightly"
11 changes: 3 additions & 8 deletions wasm/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ web-time = "0.2"
# tlsn-prover = { path = "../tlsn/tlsn/tlsn-prover", features = ["tracing"] }
[dependencies.tlsn-prover]
git = "https://github.com/mhchia/tlsn.git"
branch = "tlsn-examples-wasm"
branch = "dev-20230921-webtime"
package = "tlsn-prover"
features = ["tracing"]

Expand Down Expand Up @@ -74,16 +74,11 @@ branch="0.16.20-cleanup"
git="https://github.com/mhchia/ws_stream_wasm"
branch="dev"

[patch.'https://github.com/tlsnotary/tlsn-utils']
# [patch.'https://github.com/tlsnotary/tlsn-utils']
# # Use single cpu backend
# tlsn-utils = { git = 'https://www.github.com/mhchia/tlsn-utils.git', rev = "46327f0" }
# tlsn-utils-aio = { git = 'https://www.github.com/mhchia/tlsn-utils.git', rev = "46327f0" }

# Use older version of multi-threaded backend
tlsn-utils = { git = 'https://www.github.com/tlsnotary/tlsn-utils.git', rev = "f3e3f07" }
tlsn-utils-aio = { git = 'https://www.github.com/tlsnotary/tlsn-utils.git', rev = "f3e3f07" }


# The `console_error_panic_hook` crate provides better debugging of panics by
# logging them with `console.error`. This is great for development, but requires
# all the `std::fmt` and `std::panicking` infrastructure, so isn't great for
Expand All @@ -97,4 +92,4 @@ wasm-bindgen-test = "0.3.34"
# Tell `rustc` to optimize for small code size.

[package.metadata.wasm-pack.profile.release]
wasm-opt = false
wasm-opt = false
100 changes: 51 additions & 49 deletions wasm/prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use web_time::Instant;
use hyper::{body::to_bytes, Body, Request, StatusCode};
use futures::{AsyncWriteExt, TryFutureExt};
use futures::channel::oneshot;
use tlsn_prover::{bind_prover, ProverConfig};
use tlsn_prover::{Prover, ProverConfig};

// use tokio::io::AsyncWriteExt as _;
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
Expand Down Expand Up @@ -84,12 +84,14 @@ pub async fn prover() -> Result<(), JsValue> {
// let message = b"Hello from browser".to_vec();
// notary_ws_stream_into.write(&message).await
// .expect_throw( "Failed to write to websocket" );
// log!("!@# 0.1");

// let mut output = [0u8; 20];
// let bytes = notary_ws_stream_into.read(&mut output[..]).await.unwrap();
// assert_eq!(bytes, 18);
// log!("Received: {:?}", &output[..bytes]);


// Basic default prover config
let config = ProverConfig::builder()
.id("example")
Expand All @@ -99,32 +101,20 @@ pub async fn prover() -> Result<(), JsValue> {

log!("!@# 1");


log!("!@# 2");
let (tls_connection, prover_fut, mux_fut) =
bind_prover(config, client_ws_stream_into, notary_ws_stream_into)
// Create a Prover and set it up with the Notary
// This will set up the MPC backend prior to connecting to the server.
let prover = Prover::new(config)
.setup(notary_ws_stream_into)
.await
.unwrap();
log!("!@# 3");


// Spawn the Prover and Mux tasks to be run concurrently
// tokio::spawn(mux_fut);
let handled_mux_fut = async {
log!("!@# 4");
match mux_fut.await {
Ok(_) => {
log!("!@# 4.1");
()
},
Err(err) => {
panic!("An error occurred in mux_fut: {:?}", err);
}
}
};
log!("!@# 5");
spawn_local(handled_mux_fut);
log!("!@# 6");
// Bind the Prover to the server connection.
// The returned `mpc_tls_connection` is an MPC TLS connection to the Server: all data written
// to/read from it will be encrypted/decrypted using MPC with the Notary.
let (mpc_tls_connection, prover_fut) = prover.connect(client_ws_stream_into).await.unwrap();

log!("!@# 3");


// let prover_task = tokio::spawn(prover_fut);
Expand All @@ -140,12 +130,11 @@ pub async fn prover() -> Result<(), JsValue> {
}
}
};
// let prover_task = tokio::spawn(prover_fut);
spawn_local(handled_prover_fut);
log!("!@# 7");

// Attach the hyper HTTP client to the TLS connection
let (mut request_sender, connection) = hyper::client::conn::handshake(tls_connection.compat_write())
let (mut request_sender, connection) = hyper::client::conn::handshake(mpc_tls_connection.compat())
.await
.unwrap();
log!("!@# 8");
Expand Down Expand Up @@ -192,11 +181,12 @@ pub async fn prover() -> Result<(), JsValue> {
.unwrap();


log!("Sending request");
log!("Starting an MPC TLS connection with the server");

// Send the request to the Server and get a response via the MPC TLS connection
let response = request_sender.send_request(request).await.unwrap();

log!("Sent request");
log!("Got a response from the server");

assert!(response.status() == StatusCode::OK);

Expand All @@ -220,6 +210,7 @@ pub async fn prover() -> Result<(), JsValue> {
// The Prover task should be done now, so we can grab it.
// let mut prover = prover_task.await.unwrap().unwrap();
let mut prover = prover_receiver.await.unwrap();
let mut prover = prover.start_notarize();
log!("!@# 14");

// Identify the ranges in the transcript that contain secrets
Expand All @@ -233,29 +224,40 @@ pub async fn prover() -> Result<(), JsValue> {
);
log!("!@# 15");

// Commit to the outbound transcript, isolating the data that contain secrets
for range in public_ranges.iter().chain(private_ranges.iter()) {
prover.add_commitment_sent(range.clone()).unwrap();
}
log!("!@# 16");

// Commit to the full received transcript in one shot, as we don't need to redact anything
let recv_len = prover.recv_transcript().data().len();
log!("!@# 17");
prover.add_commitment_recv(0..recv_len as u32).unwrap();
log!("!@# 18");

let builder = prover.commitment_builder();

// Commit to each range of the public outbound data which we want to disclose
let sent_commitments: Vec<_> = public_ranges
.iter()
.map(|r| builder.commit_sent(r.clone()).unwrap())
.collect();

// Commit to all inbound data in one shot, as we don't need to redact anything in it
let recv_commitment = builder.commit_recv(0..recv_len).unwrap();

// Finalize, returning the notarized session
let notarized_session = prover.finalize().await.unwrap();
log!("!@# 19");

log!("Notarization complete!");
let res_str = serde_json::to_string_pretty(&notarized_session)
.unwrap();
log!("Notarized session: {}", res_str);
// Create a proof for all committed data in this session
let session_proof = notarized_session.session_proof();

let mut proof_builder = notarized_session.data().build_substrings_proof();

// Reveal all the public ranges
for commitment_id in sent_commitments {
proof_builder.reveal(commitment_id).unwrap();
}
proof_builder.reveal(recv_commitment).unwrap();

let substrings_proof = proof_builder.build().unwrap();
let res = serde_json::to_string_pretty(&(&session_proof, &substrings_proof, &SERVER_DOMAIN))
.unwrap();
log!("res = {}", res);

let duration = start_time.elapsed();
log!("!@# request costs: {} seconds", duration.as_secs());
log!("!@# request takes: {} seconds", duration.as_secs());

Ok(())

Expand All @@ -265,12 +267,12 @@ pub async fn prover() -> Result<(), JsValue> {
/// Find the ranges of the public and private parts of a sequence.
///
/// Returns a tuple of `(public, private)` ranges.
fn find_ranges(seq: &[u8], sub_seq: &[&[u8]]) -> (Vec<Range<u32>>, Vec<Range<u32>>) {
fn find_ranges(seq: &[u8], private_seq: &[&[u8]]) -> (Vec<Range<usize>>, Vec<Range<usize>>) {
let mut private_ranges = Vec::new();
for s in sub_seq {
for s in private_seq {
for (idx, w) in seq.windows(s.len()).enumerate() {
if w == *s {
private_ranges.push(idx as u32..(idx + w.len()) as u32);
private_ranges.push(idx..(idx + w.len()));
}
}
}
Expand All @@ -287,9 +289,9 @@ fn find_ranges(seq: &[u8], sub_seq: &[&[u8]]) -> (Vec<Range<u32>>, Vec<Range<u32
last_end = r.end;
}

if last_end < seq.len() as u32 {
public_ranges.push(last_end..seq.len() as u32);
if last_end < seq.len() {
public_ranges.push(last_end..seq.len());
}

(public_ranges, private_ranges)
}
}
41 changes: 41 additions & 0 deletions worker.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import * as Comlink from 'comlink';
// import init, { prover } from "./pkg/tlsn_extension_rs";
import init, { initThreadPool, prover } from './pkg/tlsn_extension_rs';

function hasSharedMemory() {
const hasSharedArrayBuffer = 'SharedArrayBuffer' in global;
const notCrossOriginIsolated = global.crossOriginIsolated === false;

return hasSharedArrayBuffer && !notCrossOriginIsolated;
}

const DATA = Array(20).fill(1);

class Test {
constructor() {
console.log('!@# test comlink');
this.test();
}

async test() {
console.log('start');
console.log('!@# hasSharedMemory=', hasSharedMemory());
const numConcurrency = navigator.hardwareConcurrency;
console.log('!@# numConcurrency=', numConcurrency);
const res = await init();
console.log('!@# res.memory=', res.memory);
// 6422528 ~= 6.12 mb
console.log('!@# res.memory.buffer.length=', res.memory.buffer.byteLength);
await initThreadPool(numConcurrency);
const resProver = await prover();
console.log('!@# resProver=', resProver);
console.log('!@# resAfter.memory=', res.memory);
// 1105920000 ~= 1.03 gb
console.log(
'!@# resAfter.memory.buffer.length=',
res.memory.buffer.byteLength,
);
}
}

Comlink.expose(Test);

0 comments on commit 2c1c634

Please sign in to comment.