From 1a1569547ca6c692d1e50472ab9622fc7fe7c5cd Mon Sep 17 00:00:00 2001 From: Jess Woods <45497968+jkwoods@users.noreply.github.com> Date: Thu, 7 Dec 2023 14:24:58 -0500 Subject: [PATCH] Verifier Optimization (#9) * make s_i scalars * why not working * batch inv * compilation, tests suspiciously all pass * alright, convinced * comments, clippy * comments, clippy * clean up * clean up --- src/provider/ipa_pc.rs | 213 ++++++++++++++++++++++++++++------------- src/spartan/direct.rs | 7 +- 2 files changed, 148 insertions(+), 72 deletions(-) diff --git a/src/provider/ipa_pc.rs b/src/provider/ipa_pc.rs index af1f3714..5432d902 100644 --- a/src/provider/ipa_pc.rs +++ b/src/provider/ipa_pc.rs @@ -508,49 +508,6 @@ where )) } - fn bullet_reduce_verifier( - P: &Commitment, - P_L: &Commitment, - P_R: &Commitment, - a_vec: &[G::Scalar], - gens: &CommitmentGens, - transcript: &mut Transcript, - ) -> Result< - ( - Commitment, // P' - Vec, // a_vec' - CommitmentGens, // gens' - ), - NovaError, - > { - let n = a_vec.len(); - - P_L.append_to_transcript(b"L", transcript); - P_R.append_to_transcript(b"R", transcript); - - let chal = G::Scalar::challenge(b"challenge_r", transcript); - - // println!("Challenge in bullet_reduce_verifier {:?}", chal); - - let chal_square = chal * chal; - let chal_inverse = chal.invert().unwrap(); - let chal_inverse_square = chal_inverse * chal_inverse; - - // This takes care of splitting them in half and multiplying left half - // by chal_inverse and right half by chal - let gens_prime = gens.fold(&chal_inverse, &chal); - - let a_vec_prime = a_vec[0..n / 2] - .par_iter() - .zip(a_vec[n / 2..n].par_iter()) - .map(|(a_L, a_R)| *a_L * chal_inverse + chal * *a_R) - .collect::>(); - - let P_prime = *P_L * chal_square + *P + *P_R * chal_inverse_square; - - Ok((P_prime, a_vec_prime, gens_prime)) - } - /// prover inner product argument pub fn prove( gens: &CommitmentGens, @@ -661,6 +618,107 @@ where }) } + // from Spartan, notably without the zeroizing buffer + fn batch_invert(inputs: &mut [G::Scalar]) -> G::Scalar { + // This code is essentially identical to the FieldElement + // implementation, and is documented there. Unfortunately, + // it's not easy to write it generically, since here we want + // to use `UnpackedScalar`s internally, and `Scalar`s + // externally, but there's no corresponding distinction for + // field elements. + + let n = inputs.len(); + let one = G::Scalar::one(); + + // Place scratch storage in a Zeroizing wrapper to wipe it when + // we pass out of scope. + let mut scratch = vec![one; n]; + //let mut scratch = Zeroizing::new(scratch_vec); + + // Keep an accumulator of all of the previous products + let mut acc = G::Scalar::one(); + + // Pass through the input vector, recording the previous + // products in the scratch space + for (input, scratch) in inputs.iter().zip(scratch.iter_mut()) { + *scratch = acc; + + acc *= input; + } + + // acc is nonzero iff all inputs are nonzero + debug_assert!(acc != G::Scalar::zero()); + + // Compute the inverse of all products + acc = acc.invert().unwrap(); + + // We need to return the product of all inverses later + let ret = acc; + + // Pass through the vector backwards to compute the inverses + // in place + for (input, scratch) in inputs.iter_mut().rev().zip(scratch.iter().rev()) { + let tmp = acc * *input; + *input = acc * scratch; + acc = tmp; + } + + ret + } + + // copied almost directly from the Spartan method, with some type massaging + fn verification_scalars( + &self, + n: usize, + transcript: &mut Transcript, + ) -> Result<(Vec, Vec, Vec), NovaError> { + let lg_n = self.P_L_vec.len(); + if lg_n >= 32 { + // 4 billion multiplications should be enough for anyone + // and this check prevents overflow in 1< = self + .P_L_vec + .iter() + .map(|p| p.decompress().unwrap().reinterpret_as_generator()) + .collect(); + let mut Rs: Vec = self + .P_R_vec + .iter() + .map(|p| p.decompress().unwrap().reinterpret_as_generator()) + .collect(); + + Ls.append(&mut Rs); + Ls.push(P.reinterpret_as_generator()); + + u_sq.append(&mut u_inv_sq); + u_sq.push(G::Scalar::one()); - // Step 3 in Hyrax's Figure 7 + let P_comm = G::vartime_multiscalar_mul(&u_sq, &Ls[..]); + + // Step 3 in Hyrax's Figure 8 self.beta.append_to_transcript(b"beta", transcript); self.delta.append_to_transcript(b"delta", transcript); let chal = G::Scalar::challenge(b"chal_z", transcript); - let P_plus_beta = P * chal + self.beta.decompress().unwrap(); - let P_plus_beta_to_a = P_plus_beta * a_vec[0]; - let left_hand_side = P_plus_beta_to_a + self.delta.decompress().unwrap(); - - let g_hat = CE::::commit(&gens, &[G::Scalar::one()], &G::Scalar::zero()); - let g_to_a = CE::::commit(&gens_y, &a_vec, &G::Scalar::zero()); // g^a*h^0 = g^a - let h_to_z2 = CE::::commit(&gens_y, &[G::Scalar::zero()], &self.z_2); // g^0 * h^z2 = h^z2 + // Step 5 in Hyrax's Figure 8 + // P^(chal*a) * beta^a * delta^1 + let left_hand_side = G::vartime_multiscalar_mul( + &[(chal * a), a, G::Scalar::one()], + &[ + P_comm.preprocessed(), + self.beta.decompress().unwrap().reinterpret_as_generator(), + self.delta.decompress().unwrap().reinterpret_as_generator(), + ], + ); - let g_hat_plus_g_to_a = g_hat + g_to_a; - let val_to_z1 = g_hat_plus_g_to_a * self.z_1; - let right_hand_side = val_to_z1 + h_to_z2; + // g_hat^z1 * g^(a*z1) * h^z2 + let right_hand_side = G::vartime_multiscalar_mul( + &[self.z_1, (self.z_1 * a), self.z_2], + &[ + g_hat.preprocessed(), + gens_y.get_gens()[0].clone(), + gens_y.get_blinding_gen(), + ], + ); if left_hand_side == right_hand_side { Ok(()) diff --git a/src/spartan/direct.rs b/src/spartan/direct.rs index b6778af9..a5e11e6c 100644 --- a/src/spartan/direct.rs +++ b/src/spartan/direct.rs @@ -467,7 +467,7 @@ mod tests { // verify the SNARK let z_out = circuit.output(&z_0); - let io = z_0.into_iter().chain(z_out.into_iter()).collect::>(); + let io = z_0.into_iter().chain(z_out).collect::>(); let res = snark.cap_verify(&vk, &io, &com_v); assert!(res.is_ok()); } @@ -542,10 +542,7 @@ mod tests { let snark = res.unwrap(); // verify the SNARK - let io = input - .into_iter() - .chain(output.clone().into_iter()) - .collect::>(); + let io = input.into_iter().chain(output.clone()).collect::>(); let res = snark.verify(&vk, &io); assert!(res.is_ok());