Skip to content

Commit

Permalink
fix: drop connection instead of manual close, enable deferred decrypt…
Browse files Browse the repository at this point in the history
…ion (#472)
  • Loading branch information
sinui0 authored Apr 9, 2024
1 parent b4334ad commit 68b9474
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions tlsn/examples/interactive/interactive.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use futures::AsyncWriteExt;
use http_body_util::Empty;
use hyper::{body::Bytes, Request, StatusCode, Uri};
use hyper_util::rt::TokioIo;
Expand Down Expand Up @@ -49,6 +48,8 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let server_port = uri.port_u16().unwrap_or(443);

// Create prover and connect to verifier.
//
// Perform the setup phase with the verifier.
let prover = Prover::new(
ProverConfig::builder()
.id(id)
Expand All @@ -64,9 +65,18 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let tls_client_socket = tokio::net::TcpStream::connect((server_domain, server_port))
.await
.unwrap();

// Pass server connection into the prover.
let (mpc_tls_connection, prover_fut) =
prover.connect(tls_client_socket.compat()).await.unwrap();

// Grab a controller for the Prover so we can enable deferred decryption.
let ctrl = prover_fut.control();

// Wrap the connection in a TokioIo compatibility layer to use it with hyper.
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat());

// Spawn the Prover to run in the background.
let prover_task = tokio::spawn(prover_fut);

// MPC-TLS Handshake.
Expand All @@ -75,7 +85,12 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
.await
.unwrap();

let connection_task = tokio::spawn(connection.without_shutdown());
// Spawn the connection to run in the background.
tokio::spawn(connection);

// Enable deferred decryption. This speeds up the proving time, but doesn't
// let us see the decrypted data until after the connection is closed.
ctrl.defer_decryption().await.unwrap();

// MPC-TLS: Send Request and wait for Response.
let request = Request::builder()
Expand All @@ -90,10 +105,6 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(

assert!(response.status() == StatusCode::OK);

// Close TLS Connection.
let tls_connection = connection_task.await.unwrap().unwrap().io.into_inner();
tls_connection.compat().close().await.unwrap();

// Create proof for the Verifier.
let mut prover = prover_task.await.unwrap().unwrap().start_prove();
redact_and_reveal_received_data(&mut prover);
Expand Down Expand Up @@ -128,6 +139,7 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
response
.find("BEGIN PUBLIC KEY")
.expect("Expected valid public key in JSON response");

// Check Session info: server name.
assert_eq!(session_info.server_name.as_str(), SERVER_DOMAIN);

Expand Down

0 comments on commit 68b9474

Please sign in to comment.