Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Verifier Optimization #9

Merged
merged 9 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 146 additions & 67 deletions src/provider/ipa_pc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,49 +508,6 @@ where
))
}

fn bullet_reduce_verifier(
P: &Commitment<G>,
P_L: &Commitment<G>,
P_R: &Commitment<G>,
a_vec: &[G::Scalar],
gens: &CommitmentGens<G>,
transcript: &mut Transcript,
) -> Result<
(
Commitment<G>, // P'
Vec<G::Scalar>, // a_vec'
CommitmentGens<G>, // gens'
),
NovaError,
> {
let n = a_vec.len();

P_L.append_to_transcript(b"L", transcript);
P_R.append_to_transcript(b"R", transcript);

let chal = G::Scalar::challenge(b"challenge_r", transcript);

// println!("Challenge in bullet_reduce_verifier {:?}", chal);

let chal_square = chal * chal;
let chal_inverse = chal.invert().unwrap();
let chal_inverse_square = chal_inverse * chal_inverse;

// This takes care of splitting them in half and multiplying left half
// by chal_inverse and right half by chal
let gens_prime = gens.fold(&chal_inverse, &chal);

let a_vec_prime = a_vec[0..n / 2]
.par_iter()
.zip(a_vec[n / 2..n].par_iter())
.map(|(a_L, a_R)| *a_L * chal_inverse + chal * *a_R)
.collect::<Vec<G::Scalar>>();

let P_prime = *P_L * chal_square + *P + *P_R * chal_inverse_square;

Ok((P_prime, a_vec_prime, gens_prime))
}

/// prover inner product argument
pub fn prove(
gens: &CommitmentGens<G>,
Expand Down Expand Up @@ -661,6 +618,107 @@ where
})
}

// from Spartan, notably without the zeroizing buffer
fn batch_invert(inputs: &mut [G::Scalar]) -> G::Scalar {
// This code is essentially identical to the FieldElement
// implementation, and is documented there. Unfortunately,
// it's not easy to write it generically, since here we want
// to use `UnpackedScalar`s internally, and `Scalar`s
// externally, but there's no corresponding distinction for
// field elements.

let n = inputs.len();
let one = G::Scalar::one();

// Place scratch storage in a Zeroizing wrapper to wipe it when
// we pass out of scope.
let mut scratch = vec![one; n];
//let mut scratch = Zeroizing::new(scratch_vec);

// Keep an accumulator of all of the previous products
let mut acc = G::Scalar::one();

// Pass through the input vector, recording the previous
// products in the scratch space
for (input, scratch) in inputs.iter().zip(scratch.iter_mut()) {
*scratch = acc;

acc *= input;
}

// acc is nonzero iff all inputs are nonzero
debug_assert!(acc != G::Scalar::zero());

// Compute the inverse of all products
acc = acc.invert().unwrap();

// We need to return the product of all inverses later
let ret = acc;

// Pass through the vector backwards to compute the inverses
// in place
for (input, scratch) in inputs.iter_mut().rev().zip(scratch.iter().rev()) {
let tmp = acc * *input;
*input = acc * scratch;
acc = tmp;
}

ret
}

// copied almost directly from the Spartan method, with some type massaging
fn verification_scalars(
&self,
n: usize,
transcript: &mut Transcript,
) -> Result<(Vec<G::Scalar>, Vec<G::Scalar>, Vec<G::Scalar>), NovaError> {
let lg_n = self.P_L_vec.len();
if lg_n >= 32 {
// 4 billion multiplications should be enough for anyone
// and this check prevents overflow in 1<<lg_n below.
return Err(NovaError::ProofVerifyError);
}
if n != (1 << lg_n) {
return Err(NovaError::ProofVerifyError);
}

let mut challenges = Vec::with_capacity(lg_n);

// Recompute x_k,...,x_1 based on the proof transcript
for (P_L, P_R) in self.P_L_vec.iter().zip(self.P_R_vec.iter()) {
P_L.append_to_transcript(b"L", transcript);
P_R.append_to_transcript(b"R", transcript);

challenges.push(G::Scalar::challenge(b"challenge_r", transcript));
}

// inverses
let mut challenges_inv = challenges.clone();
let prod_all_inv = Self::batch_invert(&mut challenges_inv);

// squares of challenges & inverses
for i in 0..lg_n {
challenges[i] = challenges[i].square();
challenges_inv[i] = challenges_inv[i].square();
}
let challenges_sq = challenges;
let challenges_inv_sq = challenges_inv;

// s values inductively
let mut s = Vec::with_capacity(n);
s.push(prod_all_inv);
for i in 1..n {
let lg_i = (32 - 1 - (i as u32).leading_zeros()) as usize;
let k = 1 << lg_i;
// The challenges are stored in "creation order" as [u_k,...,u_1],
// so u_{lg(i)+1} = is indexed by (lg_n-1) - lg_i
let u_lg_i_sq = challenges_sq[(lg_n - 1) - lg_i];
s.push(s[i - k] * u_lg_i_sq);
}

Ok((challenges_sq, challenges_inv_sq, s))
}

/// verify inner product argument
pub fn verify(
&self,
Expand All @@ -686,41 +744,62 @@ where
// Scaling to be compatible with Bulletproofs figure 1
let chal = G::Scalar::challenge(b"r", transcript); // sample a random challenge for scaling commitment
let gens_y = gens_y.scale(&chal);
let mut P = U.comm_x_vec + U.comm_y * chal;
let P = U.comm_x_vec + U.comm_y * chal;

let mut gens = gens.clone();
let mut a_vec = U.a_vec.clone();
let a_vec = U.a_vec.clone();

// Step 1 in Hyrax's figure 7.
for i in 0..self.P_L_vec.len() {
let P_L = self.P_L_vec[i].decompress().unwrap();
let P_R = self.P_R_vec[i].decompress().unwrap();
// calculate all the exponent challenges (s) and inverses at once
let (mut u_sq, mut u_inv_sq, s) = self.verification_scalars(n, transcript)?;

let (P_prime, a_vec_prime, gens_prime) =
Self::bullet_reduce_verifier(&P, &P_L, &P_R, &a_vec, &gens, transcript)?;
// do all the exponentiations at once (Hyrax, Fig. 7, step 4, all rounds)
let g_hat = G::vartime_multiscalar_mul(&s, &gens.get_gens());
let a = inner_product(&a_vec[..], &s[..]);

P = P_prime;
a_vec = a_vec_prime;
gens = gens_prime;
}
let mut Ls: Vec<G::PreprocessedGroupElement> = self
.P_L_vec
.iter()
.map(|p| p.decompress().unwrap().reinterpret_as_generator())
.collect();
let mut Rs: Vec<G::PreprocessedGroupElement> = self
.P_R_vec
.iter()
.map(|p| p.decompress().unwrap().reinterpret_as_generator())
.collect();

Ls.append(&mut Rs);
Ls.push(P.reinterpret_as_generator());

u_sq.append(&mut u_inv_sq);
u_sq.push(G::Scalar::one());

// Step 3 in Hyrax's Figure 7
let P_comm = G::vartime_multiscalar_mul(&u_sq, &Ls[..]);

// Step 3 in Hyrax's Figure 8
self.beta.append_to_transcript(b"beta", transcript);
self.delta.append_to_transcript(b"delta", transcript);

let chal = G::Scalar::challenge(b"chal_z", transcript);

let P_plus_beta = P * chal + self.beta.decompress().unwrap();
let P_plus_beta_to_a = P_plus_beta * a_vec[0];
let left_hand_side = P_plus_beta_to_a + self.delta.decompress().unwrap();

let g_hat = CE::<G>::commit(&gens, &[G::Scalar::one()], &G::Scalar::zero());
let g_to_a = CE::<G>::commit(&gens_y, &a_vec, &G::Scalar::zero()); // g^a*h^0 = g^a
let h_to_z2 = CE::<G>::commit(&gens_y, &[G::Scalar::zero()], &self.z_2); // g^0 * h^z2 = h^z2
// Step 5 in Hyrax's Figure 8
// P^(chal*a) * beta^a * delta^1
let left_hand_side = G::vartime_multiscalar_mul(
&[(chal * a), a, G::Scalar::one()],
&[
P_comm.preprocessed(),
self.beta.decompress().unwrap().reinterpret_as_generator(),
self.delta.decompress().unwrap().reinterpret_as_generator(),
],
);

let g_hat_plus_g_to_a = g_hat + g_to_a;
let val_to_z1 = g_hat_plus_g_to_a * self.z_1;
let right_hand_side = val_to_z1 + h_to_z2;
// g_hat^z1 * g^(a*z1) * h^z2
let right_hand_side = G::vartime_multiscalar_mul(
&[self.z_1, (self.z_1 * a), self.z_2],
&[
g_hat.preprocessed(),
gens_y.get_gens()[0].clone(),
gens_y.get_blinding_gen(),
],
);

if left_hand_side == right_hand_side {
Ok(())
Expand Down
7 changes: 2 additions & 5 deletions src/spartan/direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ mod tests {

// verify the SNARK
let z_out = circuit.output(&z_0);
let io = z_0.into_iter().chain(z_out.into_iter()).collect::<Vec<_>>();
let io = z_0.into_iter().chain(z_out).collect::<Vec<_>>();
let res = snark.cap_verify(&vk, &io, &com_v);
assert!(res.is_ok());
}
Expand Down Expand Up @@ -542,10 +542,7 @@ mod tests {
let snark = res.unwrap();

// verify the SNARK
let io = input
.into_iter()
.chain(output.clone().into_iter())
.collect::<Vec<_>>();
let io = input.into_iter().chain(output.clone()).collect::<Vec<_>>();
let res = snark.verify(&vk, &io);
assert!(res.is_ok());

Expand Down
Loading