Skip to content

Commit

Permalink
Verifier Optimization (#9)
Browse files Browse the repository at this point in the history
* make s_i scalars

* why not working

* batch inv

* compilation, tests suspiciously all pass

* alright, convinced

* comments, clippy

* comments, clippy

* clean up

* clean up
  • Loading branch information
jkwoods authored Dec 7, 2023
1 parent e107eac commit 1a15695
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 72 deletions.
213 changes: 146 additions & 67 deletions src/provider/ipa_pc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,49 +508,6 @@ where
))
}

fn bullet_reduce_verifier(
P: &Commitment<G>,
P_L: &Commitment<G>,
P_R: &Commitment<G>,
a_vec: &[G::Scalar],
gens: &CommitmentGens<G>,
transcript: &mut Transcript,
) -> Result<
(
Commitment<G>, // P'
Vec<G::Scalar>, // a_vec'
CommitmentGens<G>, // gens'
),
NovaError,
> {
let n = a_vec.len();

P_L.append_to_transcript(b"L", transcript);
P_R.append_to_transcript(b"R", transcript);

let chal = G::Scalar::challenge(b"challenge_r", transcript);

// println!("Challenge in bullet_reduce_verifier {:?}", chal);

let chal_square = chal * chal;
let chal_inverse = chal.invert().unwrap();
let chal_inverse_square = chal_inverse * chal_inverse;

// This takes care of splitting them in half and multiplying left half
// by chal_inverse and right half by chal
let gens_prime = gens.fold(&chal_inverse, &chal);

let a_vec_prime = a_vec[0..n / 2]
.par_iter()
.zip(a_vec[n / 2..n].par_iter())
.map(|(a_L, a_R)| *a_L * chal_inverse + chal * *a_R)
.collect::<Vec<G::Scalar>>();

let P_prime = *P_L * chal_square + *P + *P_R * chal_inverse_square;

Ok((P_prime, a_vec_prime, gens_prime))
}

/// prover inner product argument
pub fn prove(
gens: &CommitmentGens<G>,
Expand Down Expand Up @@ -661,6 +618,107 @@ where
})
}

// from Spartan, notably without the zeroizing buffer
fn batch_invert(inputs: &mut [G::Scalar]) -> G::Scalar {
// This code is essentially identical to the FieldElement
// implementation, and is documented there. Unfortunately,
// it's not easy to write it generically, since here we want
// to use `UnpackedScalar`s internally, and `Scalar`s
// externally, but there's no corresponding distinction for
// field elements.

let n = inputs.len();
let one = G::Scalar::one();

// Place scratch storage in a Zeroizing wrapper to wipe it when
// we pass out of scope.
let mut scratch = vec![one; n];
//let mut scratch = Zeroizing::new(scratch_vec);

// Keep an accumulator of all of the previous products
let mut acc = G::Scalar::one();

// Pass through the input vector, recording the previous
// products in the scratch space
for (input, scratch) in inputs.iter().zip(scratch.iter_mut()) {
*scratch = acc;

acc *= input;
}

// acc is nonzero iff all inputs are nonzero
debug_assert!(acc != G::Scalar::zero());

// Compute the inverse of all products
acc = acc.invert().unwrap();

// We need to return the product of all inverses later
let ret = acc;

// Pass through the vector backwards to compute the inverses
// in place
for (input, scratch) in inputs.iter_mut().rev().zip(scratch.iter().rev()) {
let tmp = acc * *input;
*input = acc * scratch;
acc = tmp;
}

ret
}

// copied almost directly from the Spartan method, with some type massaging
fn verification_scalars(
&self,
n: usize,
transcript: &mut Transcript,
) -> Result<(Vec<G::Scalar>, Vec<G::Scalar>, Vec<G::Scalar>), NovaError> {
let lg_n = self.P_L_vec.len();
if lg_n >= 32 {
// 4 billion multiplications should be enough for anyone
// and this check prevents overflow in 1<<lg_n below.
return Err(NovaError::ProofVerifyError);
}
if n != (1 << lg_n) {
return Err(NovaError::ProofVerifyError);
}

let mut challenges = Vec::with_capacity(lg_n);

// Recompute x_k,...,x_1 based on the proof transcript
for (P_L, P_R) in self.P_L_vec.iter().zip(self.P_R_vec.iter()) {
P_L.append_to_transcript(b"L", transcript);
P_R.append_to_transcript(b"R", transcript);

challenges.push(G::Scalar::challenge(b"challenge_r", transcript));
}

// inverses
let mut challenges_inv = challenges.clone();
let prod_all_inv = Self::batch_invert(&mut challenges_inv);

// squares of challenges & inverses
for i in 0..lg_n {
challenges[i] = challenges[i].square();
challenges_inv[i] = challenges_inv[i].square();
}
let challenges_sq = challenges;
let challenges_inv_sq = challenges_inv;

// s values inductively
let mut s = Vec::with_capacity(n);
s.push(prod_all_inv);
for i in 1..n {
let lg_i = (32 - 1 - (i as u32).leading_zeros()) as usize;
let k = 1 << lg_i;
// The challenges are stored in "creation order" as [u_k,...,u_1],
// so u_{lg(i)+1} = is indexed by (lg_n-1) - lg_i
let u_lg_i_sq = challenges_sq[(lg_n - 1) - lg_i];
s.push(s[i - k] * u_lg_i_sq);
}

Ok((challenges_sq, challenges_inv_sq, s))
}

/// verify inner product argument
pub fn verify(
&self,
Expand All @@ -686,41 +744,62 @@ where
// Scaling to be compatible with Bulletproofs figure 1
let chal = G::Scalar::challenge(b"r", transcript); // sample a random challenge for scaling commitment
let gens_y = gens_y.scale(&chal);
let mut P = U.comm_x_vec + U.comm_y * chal;
let P = U.comm_x_vec + U.comm_y * chal;

let mut gens = gens.clone();
let mut a_vec = U.a_vec.clone();
let a_vec = U.a_vec.clone();

// Step 1 in Hyrax's figure 7.
for i in 0..self.P_L_vec.len() {
let P_L = self.P_L_vec[i].decompress().unwrap();
let P_R = self.P_R_vec[i].decompress().unwrap();
// calculate all the exponent challenges (s) and inverses at once
let (mut u_sq, mut u_inv_sq, s) = self.verification_scalars(n, transcript)?;

let (P_prime, a_vec_prime, gens_prime) =
Self::bullet_reduce_verifier(&P, &P_L, &P_R, &a_vec, &gens, transcript)?;
// do all the exponentiations at once (Hyrax, Fig. 7, step 4, all rounds)
let g_hat = G::vartime_multiscalar_mul(&s, &gens.get_gens());
let a = inner_product(&a_vec[..], &s[..]);

P = P_prime;
a_vec = a_vec_prime;
gens = gens_prime;
}
let mut Ls: Vec<G::PreprocessedGroupElement> = self
.P_L_vec
.iter()
.map(|p| p.decompress().unwrap().reinterpret_as_generator())
.collect();
let mut Rs: Vec<G::PreprocessedGroupElement> = self
.P_R_vec
.iter()
.map(|p| p.decompress().unwrap().reinterpret_as_generator())
.collect();

Ls.append(&mut Rs);
Ls.push(P.reinterpret_as_generator());

u_sq.append(&mut u_inv_sq);
u_sq.push(G::Scalar::one());

// Step 3 in Hyrax's Figure 7
let P_comm = G::vartime_multiscalar_mul(&u_sq, &Ls[..]);

// Step 3 in Hyrax's Figure 8
self.beta.append_to_transcript(b"beta", transcript);
self.delta.append_to_transcript(b"delta", transcript);

let chal = G::Scalar::challenge(b"chal_z", transcript);

let P_plus_beta = P * chal + self.beta.decompress().unwrap();
let P_plus_beta_to_a = P_plus_beta * a_vec[0];
let left_hand_side = P_plus_beta_to_a + self.delta.decompress().unwrap();

let g_hat = CE::<G>::commit(&gens, &[G::Scalar::one()], &G::Scalar::zero());
let g_to_a = CE::<G>::commit(&gens_y, &a_vec, &G::Scalar::zero()); // g^a*h^0 = g^a
let h_to_z2 = CE::<G>::commit(&gens_y, &[G::Scalar::zero()], &self.z_2); // g^0 * h^z2 = h^z2
// Step 5 in Hyrax's Figure 8
// P^(chal*a) * beta^a * delta^1
let left_hand_side = G::vartime_multiscalar_mul(
&[(chal * a), a, G::Scalar::one()],
&[
P_comm.preprocessed(),
self.beta.decompress().unwrap().reinterpret_as_generator(),
self.delta.decompress().unwrap().reinterpret_as_generator(),
],
);

let g_hat_plus_g_to_a = g_hat + g_to_a;
let val_to_z1 = g_hat_plus_g_to_a * self.z_1;
let right_hand_side = val_to_z1 + h_to_z2;
// g_hat^z1 * g^(a*z1) * h^z2
let right_hand_side = G::vartime_multiscalar_mul(
&[self.z_1, (self.z_1 * a), self.z_2],
&[
g_hat.preprocessed(),
gens_y.get_gens()[0].clone(),
gens_y.get_blinding_gen(),
],
);

if left_hand_side == right_hand_side {
Ok(())
Expand Down
7 changes: 2 additions & 5 deletions src/spartan/direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ mod tests {

// verify the SNARK
let z_out = circuit.output(&z_0);
let io = z_0.into_iter().chain(z_out.into_iter()).collect::<Vec<_>>();
let io = z_0.into_iter().chain(z_out).collect::<Vec<_>>();
let res = snark.cap_verify(&vk, &io, &com_v);
assert!(res.is_ok());
}
Expand Down Expand Up @@ -542,10 +542,7 @@ mod tests {
let snark = res.unwrap();

// verify the SNARK
let io = input
.into_iter()
.chain(output.clone().into_iter())
.collect::<Vec<_>>();
let io = input.into_iter().chain(output.clone()).collect::<Vec<_>>();
let res = snark.verify(&vk, &io);
assert!(res.is_ok());

Expand Down

0 comments on commit 1a15695

Please sign in to comment.